mnist手写体数据集是人工智能中最简单的入门数据集之一,数据集下载的链接是:
fashion-mnist数据集的存储形式和mnist数据集完全一样,数据集下载的链接是:
- 程序讲解:程序分成三个部分
- 首先使用load_mnist()加载mnist数据集或fashion-mnist数据集。两者的格式完全相同,所以建议将它们保存到不同的文件夹下,通过指定文件夹来选择要加载的数据集
- 其次是for循环:按批次取出图像并保存,同时打印对应的标签
- 最后是save_images:按照网格(框架)拼接并保存图像,保存前记得调整图像的数值范围
def load_mnist(data_dir="../Dataset/mnist_data/"):
    """Load the MNIST (or Fashion-MNIST) train and test sets as one shuffled array.

    Both datasets use the same gzipped IDX file layout, so pass
    ``data_dir="../Dataset/fashion-mnist/"`` to load Fashion-MNIST instead.

    Args:
        data_dir: directory containing train-images-idx3-ubyte.gz,
            train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and
            t10k-labels-idx1-ubyte.gz.

    Returns:
        (X, y_vec): X is a float64 array of shape (N, rows, cols, 1) scaled
        to [0, 1] (e.g. (70000, 28, 28, 1) for MNIST); y_vec is a float64
        one-hot array of shape (N, 10). Train and test samples are
        concatenated and shuffled together using the global ``np.random``
        state, so each image stays paired with its label.

    Raises:
        ValueError: if a file is not an unsigned-byte IDX file, or its
            payload is shorter than its header declares.
    """
    import os
    import struct

    def read_idx(filename):
        # IDX header: two zero bytes, a type code (0x08 = unsigned byte), the
        # number of dimensions, then one big-endian uint32 per dimension.
        path = os.path.join(data_dir, filename)
        with gzip.open(path, "rb") as stream:
            magic = stream.read(4)
            if len(magic) != 4 or magic[:3] != b"\x00\x00\x08":
                raise ValueError("{} is not an unsigned-byte IDX file".format(path))
            ndim = magic[3]
            dims = struct.unpack(">" + "I" * ndim, stream.read(4 * ndim))
            payload = stream.read()
        # np.frombuffer raises ValueError when the payload is too short.
        data = np.frombuffer(payload, dtype=np.uint8, count=int(np.prod(dims)))
        return data.reshape(dims)

    train_x = read_idx("train-images-idx3-ubyte.gz")
    train_y = read_idx("train-labels-idx1-ubyte.gz")
    test_x = read_idx("t10k-images-idx3-ubyte.gz")
    test_y = read_idx("t10k-labels-idx1-ubyte.gz")

    # Append the trailing channel axis the image code expects: (N, rows, cols, 1).
    X = np.concatenate((train_x, test_x), axis=0)[..., np.newaxis].astype(np.float64)
    y = np.concatenate((train_y, test_y), axis=0).astype(np.int64)
    print("*****************dataX**************", len(X))

    # One shared permutation keeps images and labels aligned.
    data_index = np.arange(X.shape[0])
    np.random.shuffle(data_index)
    X = X[data_index]
    y = y[data_index]

    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    y_vec[np.arange(len(y)), y] = 1.0
    return X / 255., y_vec
def merge(images, size=(8, 8)):
    """Tile a batch of images into a single grid image, filled row by row.

    Args:
        images: array or list of images shaped (N, h, w) or (N, h, w, C),
            with C equal to 1 or 3.
        size: (rows, cols) of the grid; the default 8x8 holds 64 images.

    Returns:
        float array of shape (rows * h, cols * w, 3). Single-channel images
        are broadcast into all three channels; unused cells stay zero.

    Raises:
        ValueError: if there are more images than grid cells.
    """
    images = np.asarray(images)
    if images.ndim == 3:
        # (N, h, w) grayscale: add a channel axis so it broadcasts into RGB.
        images = images[..., np.newaxis]
    rows, cols = size
    if len(images) > rows * cols:
        raise ValueError(
            "{} images do not fit in a {}x{} grid".format(len(images), rows, cols))
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((h * rows, w * cols, 3))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return grid
def _png_chunk(tag, payload):
    """Build one PNG chunk: payload length, tag, payload, CRC-32 of tag+payload."""
    import struct
    import zlib

    crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
    return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", crc)


def save_images(images, image_path):
    """Tile a batch of images with ``merge`` and write the grid to a PNG file.

    Pixel values are rescaled linearly so that the grid's minimum maps to 0
    and its maximum to 255. This is the same min-max scaling the removed
    ``scipy.misc.imsave`` applied to float arrays, so batches in [0, 1] or
    [-1, 1] both produce a full-contrast image.

    Args:
        images: batch of images in any layout ``merge`` accepts.
        image_path: destination file path; its directory must already exist.

    Returns:
        None.

    Raises:
        ValueError: if the merged grid is neither grayscale nor RGB.
    """
    import struct
    import zlib

    grid = np.asarray(merge(images), dtype=np.float64)
    if grid.ndim == 3 and grid.shape[2] == 1:
        grid = grid[:, :, 0]
    lo, hi = float(grid.min()), float(grid.max())
    span = hi - lo if hi > lo else 1.0  # constant image: avoid dividing by zero
    pixels = np.clip(np.rint((grid - lo) * (255.0 / span)), 0, 255).astype(np.uint8)

    if pixels.ndim == 2:
        color_type = 0  # PNG grayscale
    elif pixels.ndim == 3 and pixels.shape[2] == 3:
        color_type = 2  # PNG truecolor (RGB)
    else:
        raise ValueError("cannot save an image of shape {}".format(pixels.shape))

    height, width = pixels.shape[:2]
    # Every scanline starts with filter-type byte 0 ("None").
    raw = b"".join(b"\x00" + row.tobytes() for row in pixels.reshape(height, -1))
    header = struct.pack(">IIBBBBB", width, height, 8, color_type, 0, 0, 0)
    with open(image_path, "wb") as png:
        png.write(b"\x89PNG\r\n\x1a\n")
        png.write(_png_chunk(b"IHDR", header))
        png.write(_png_chunk(b"IDAT", zlib.compress(raw)))
        png.write(_png_chunk(b"IEND", b""))
import os

# Sanity check: save a few grids of real images and print the label of each
# grid cell, to confirm that images and labels are still correctly paired.
data_X, data_y = load_mnist()
result_dir = "mnist"
model_name = "image-2-image"
batch_size = 64
manifold_h = int(np.floor(np.sqrt(batch_size)))
manifold_w = int(np.floor(np.sqrt(batch_size)))
os.makedirs(result_dir, exist_ok=True)
for idx in range(5):
    batch_images = data_X[idx * batch_size:(idx + 1) * batch_size]
    batch_images_y = data_y[idx * batch_size:(idx + 1) * batch_size]
    image_path = os.path.join(
        ".", result_dir, "{}_real_image_{:04d}.png".format(model_name, idx))
    # save_images tiles into an 8x8 grid, matching batch_size = 64.
    save_images(batch_images[:manifold_h * manifold_w], image_path)
    # Print the digit labels laid out like the saved grid (one row per grid row).
    labels = np.argmax(batch_images_y[:manifold_h * manifold_w], axis=1)
    print("batch_images_y的数值是:", labels.reshape(manifold_h, manifold_w))
结果展示: