importnumpy as npimportmatplotlib.pyplot as plt‘‘‘试验transpose()
def back (a,b):
return a,b
if __name__ == ‘__main__‘:
a = np.array([[1,2,3],[11,12,13],[21,22,23]])
print(a)
b = np.array([[31,32,33],[41,42,43],[51,52,53]])
print(b)
a, b = transpose(back(a,b))
#a, b = back(a, b)
print(a)
print(b)‘‘‘
# Base class for the data loaders.
class Loader(object):
    """Base class for loaders that read samples from a binary data file.

    Subclasses build their ``load()`` on top of :meth:`get_file_content`
    and :meth:`to_int`.
    """

    def __init__(self, path, count):
        """Initialise the loader.

        Args:
            path: path of the data file.
            count: number of samples stored in the file.
        """
        self.path = path
        self.count = count

    def get_file_content(self):
        """Read and return the whole file at ``self.path`` as ``bytes``."""
        # ``with`` guarantees the handle is closed even if read() raises.
        with open(self.path, 'rb') as f:
            return f.read()

    def to_int(self, byte):
        """Convert one unsigned byte to an ``int``.

        Indexing a ``bytes`` object already yields an ``int`` on Python 3, so
        such values are returned unchanged. A length-1 ``bytes``/``str``
        (e.g. produced by slicing) is converted with ``ord``.
        """
        if isinstance(byte, int):
            return byte
        return ord(byte)


# Image data loader.
class ImageLoader(Loader):
    """Loader that reads fixed-size 28x28 images from an image file.

    Each image is stored as ``ROWS * COLS`` consecutive bytes following a
    ``HEADER_SIZE``-byte file header; bytes are converted with ``to_int``.
    """

    # Image geometry and header length used to locate each image in the file.
    ROWS = 28
    COLS = 28
    HEADER_SIZE = 16

    def get_picture(self, content, index):
        """Internal: extract image ``index`` from ``content`` as a list of rows.

        Args:
            content: raw file contents, as returned by ``get_file_content``.
            index: zero-based position of the image in the file.

        Returns:
            A list of ``ROWS`` rows, each a list of ``COLS`` values produced
            by ``to_int``.
        """
        rows, cols = self.ROWS, self.COLS
        start = index * rows * cols + self.HEADER_SIZE
        return [
            [self.to_int(content[start + i * cols + j]) for j in range(cols)]
            for i in range(rows)
        ]

    def get_one_sample(self, picture):
        """Internal: flatten a picture row by row into one input vector."""
        return [
            picture[i][j]
            for i in range(self.ROWS)
            for j in range(self.COLS)
        ]

    def load(self):
        """Load the file and return the input vectors of all ``count`` samples."""
        content = self.get_file_content()
        return [
            self.get_one_sample(self.get_picture(content, index))
            for index in range(self.count)
        ]


# Label data loader.
class LabelLoader(Loader):
    """Loader that reads one label byte per sample and encodes it as a vector."""

    # Length of the file header that precedes the first label byte.
    HEADER_SIZE = 8
    # Length of each encoded label vector.
    NUM_CLASSES = 10

    def load(self):
        """Load the file and return the label vectors of all ``count`` samples."""
        content = self.get_file_content()
        return [
            self.norm(content[index + self.HEADER_SIZE])
            for index in range(self.count)
        ]

    def norm(self, label):
        """Internal: convert one label value into a ``NUM_CLASSES``-long vector.

        The position equal to the label value is set to 0.9; every other
        position is set to 0.1.
        """
        label_value = self.to_int(label)
        return [0.9 if i == label_value else 0.1 for i in range(self.NUM_CLASSES)]


def get_training_data_set():
    """Return ``(images, labels)`` loaded from the 60000-sample training files."""
    image_loader = ImageLoader('train-images.idx3-ubyte', 60000)
    label_loader = LabelLoader('train-labels.idx1-ubyte', 60000)
    return image_loader.load(), label_loader.load()


def get_test_data_set():
    """Return ``(images, labels)`` loaded from the 10000-sample test files."""
    image_loader = ImageLoader('t10k-images.idx3-ubyte', 10000)
    label_loader = LabelLoader('t10k-labels.idx1-ubyte', 10000)
    return image_loader.load(), label_loader.load()


if __name__ == '__main__':
    # Quick visual check: display the first training image.
    # NOTE(review): this loads the full training set just to show one image.
    train_data_set, train_labels = get_training_data_set()
    line = np.array(train_data_set[0])
    img = line.reshape((28, 28))
    plt.imshow(img)
    plt.show()