以下代码需要根据实际情况（张量的形状与通道顺序）修改 `.transpose`、`.reshape` 和 `.permute` 的参数：
#@title Visualization code
# Lay out a batch of images as a square grid.
def image_grid(x, size=None, channels=None):
    """Tile a batch of channels-last images into one square grid image.

    Args:
        x: Array with ``.reshape``/``.transpose`` (e.g. a NumPy array) holding
            ``w * w`` images of ``size x size x channels`` values in
            channels-last order. Any leading shape is accepted as long as the
            total element count matches.
        size: Height/width of each image. Defaults to
            ``config.data.image_size``.
        channels: Number of channels per image. Defaults to
            ``config.data.num_channels``.

    Returns:
        Array of shape ``(w * size, w * size, channels)``. Image ``i`` of the
        batch is placed at grid row ``i // w``, grid column ``i % w``.

    Raises:
        ValueError: If the number of images is not a perfect square.
    """
    import math

    # Only consult the global config when the caller did not specify sizes.
    if size is None:
        size = config.data.image_size
    if channels is None:
        channels = config.data.num_channels
    img = x.reshape(-1, size, size, channels)
    n = img.shape[0]
    w = math.isqrt(n)  # exact integer square root; no float rounding
    if w * w != n:
        raise ValueError(
            f"image_grid needs a square number of images, got {n}"
        )
    # (grid_row, grid_col, y, x, c) -> (grid_row, y, grid_col, x, c) so pixel
    # rows of images in the same grid row become adjacent before flattening.
    img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4))
    return img.reshape((w * size, w * size, channels))
def show_samples(x):
    """Show a batch of images as a single tiled grid with matplotlib.

    Args:
        x: Batch tensor; assumed to be (N, C, H, W) — TODO confirm at caller.
            It is permuted with ``(0, 2, 3, 1)`` to move axis 1 last, then
            detached, moved to CPU and converted to NumPy before tiling.
    """
    image_size = config.data.image_size
    num_channels = config.data.num_channels
    # Images per grid side; also used as the figure size in inches.
    num_images = x.reshape(-1, image_size, image_size, num_channels).shape[0]
    grid_side = int(np.sqrt(num_images))
    # Channels-first tensor -> channels-last NumPy array for image_grid.
    channels_last = x.permute(0, 2, 3, 1)
    as_numpy = channels_last.detach().cpu().numpy()
    tiled = image_grid(as_numpy)
    plt.figure(figsize=(grid_side, grid_side))
    plt.axis('off')
    plt.imshow(tiled)
    plt.show()
或者直接使用 torchvision 自带的 `make_grid` 和 `save_image`：
from torchvision.utils import make_grid, save_image
# samples_raw holds the generated samples; shape[0] is the batch size.
nrow = int(np.sqrt(samples_raw.shape[0]))  # images per grid row
# Bind the result to `grid`, not `image_grid`: reusing that name would shadow
# the image_grid() helper defined above and break later show_samples() calls.
grid = make_grid(samples_raw, nrow, padding=2)
with tf.io.gfile.GFile(os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
    # torchvision's docs say `format` should always be given when writing to
    # a file object instead of a path.
    save_image(grid, fout, format="png")
# NOTE(review): this writes the literal "asdf" to /tmp/x whenever it runs. It
# appears unrelated to the visualization code above — confirm whether it is a
# leftover debug line, and remove it if so.
with open("/tmp/x", "w") as f: f.write("asdf")