importnumpy as npimportmatplotlib.pyplot as plt‘‘‘试验transpose()

def back (a,b):

return a,b

if __name__ == ‘__main__‘:

a = np.array([[1,2,3],[11,12,13],[21,22,23]])

print(a)

b = np.array([[31,32,33],[41,42,43],[51,52,53]])

print(b)

a, b = transpose(back(a,b))

#a, b = back(a, b)

print(a)

print(b)‘‘‘

#数据加载器基类

classLoader(object):def __init__(self, path, count):‘‘‘初始化加载器

path: 数据文件路径

count: 文件中的样本个数‘‘‘self.path=path

self.count=countdefget_file_content(self):‘‘‘读取文件内容‘‘‘f= open(self.path, ‘rb‘)

content=f.read()

f.close()returncontentdefto_int(self, byte):‘‘‘将unsigned byte字符转换为整数‘‘‘

#print(byte)

#return struct.unpack(‘B‘, byte)[0]

returnbyte#图像数据加载器

classImageLoader(Loader):defget_picture(self, content, index):‘‘‘内部函数,从文件中获取图像‘‘‘start= index * 28 * 28 + 16picture=[]for i in range(28):

picture.append([])for j in range(28):

picture[i].append(

self.to_int(content[start+ i * 28 +j]))returnpicturedefget_one_sample(self, picture):‘‘‘内部函数,将图像转化为样本的输入向量‘‘‘sample=[]for i in range(28):for j in range(28):

sample.append(picture[i][j])returnsampledefload(self):‘‘‘加载数据文件,获得全部样本的输入向量‘‘‘content=self.get_file_content()

data_set=[]for index inrange(self.count):

data_set.append(

self.get_one_sample(

self.get_picture(content, index)))returndata_set#标签数据加载器

classLabelLoader(Loader):defload(self):‘‘‘加载数据文件,获得全部样本的标签向量‘‘‘content=self.get_file_content()

labels=[]for index inrange(self.count):

labels.append(self.norm(content[index+ 8]))returnlabelsdefnorm(self, label):‘‘‘内部函数,将一个值转换为10维标签向量‘‘‘label_vec=[]

label_value=self.to_int(label)for i in range(10):if i ==label_value:

label_vec.append(0.9)else:

label_vec.append(0.1)returnlabel_vecdefget_training_data_set():‘‘‘获得训练数据集‘‘‘image_loader= ImageLoader(‘train-images.idx3-ubyte‘, 60000)

label_loader= LabelLoader(‘train-labels.idx1-ubyte‘, 60000)returnimage_loader.load(), label_loader.load()defget_test_data_set():‘‘‘获得测试数据集‘‘‘image_loader= ImageLoader(‘t10k-images.idx3-ubyte‘, 10000)

label_loader= LabelLoader(‘t10k-labels.idx1-ubyte‘, 10000)returnimage_loader.load(), label_loader.load()if __name__ == ‘__main__‘:

train_data_set, train_labels=get_training_data_set()

line=np.array(train_data_set[0])

img= line.reshape((28,28))

plt.imshow(img)

plt.show()

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐