mnist可视化
最近在研究GAN,数据集采用的是mnist数据集,为了将生成的图片与原数据进行对比,需要将mnist数据集可视化,具体代码如下,每张图片有5*5个手写数字图片。
from keras.datasets import mnist
import matplotlib.pyplot as plt
(X_train,_),(_,_)=mnist.load_data()
# X_train.shape=(60000,28,28)
r,c=5,5
w,h=28,28
batch=r*c
totalNum=X_train.shape[0]
fig,axs=plt.subplots(r,c)
for i in range(0,totalNum,batch):
print("Print %d figure"%i)
cur=X_train[i:i+batch]
cnt=0
for j in range(r):
for k in range(c):
axs[j,k].imshow(cur[cnt].reshape(w,h),cmap='gray')
axs[j,k].axis('off')
cnt+=1
fig.savefig("images/mnist_%d.jpg" % i)
采用keras加载mnist数据集,效果如下: