Motivation: Visualizing the 3-D outputs
import torch
import torchvision
import matplotlib.pyplot as plt
image = torch.randn(1, 16, 256, 256)  # (N, D, H, W): one volume with 16 depth slices of 256x256
image = image.permute(1, 0, 2, 3)  # torch.Size([16, 1, 256, 256]), viewed as (N, C, H, W) only for the visualization below
img = torchvision.utils.make_grid(image, nrow=4, normalize=True).permute(1, 2, 0)  # make_grid returns (C, H, W); permute to (H, W, C)
plt.imshow(img)
plt.show()
If we leave out .permute(1, 2, 0), plt.imshow may raise TypeError: Invalid dimensions for image data
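Reason: make_grid returns a channels-first tensor of shape (C, H, W), with the channels in dim 0. plt.imshow instead expects a channels-last image of shape (H, W, C), with the channels in dim 2. permute(1, 2, 0) moves the channel axis from dim 0 to dim 2.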
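The error and its fix can be reproduced without PyTorch. The sketch below is minimal and rests on a few assumptions: a NumPy array stands in for the make_grid output, np.transpose plays the role of .permute(1, 2, 0), and matplotlib uses the non-interactive Agg backend. The exact message varies between matplotlib versions. Newer releases report "Invalid shape ... for image data" rather than "Invalid dimensions for image data", so the code checks only that a TypeError is raised. The helper try_imshow is hypothetical and exists only for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for the make_grid output: channels-first (C, H, W), values in [0, 1].
grid_chw = np.random.rand(3, 64, 64)

def try_imshow(array):
    """Hypothetical helper: return None if imshow accepts the array, else the TypeError message."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(array)
        return None
    except TypeError as err:
        return str(err)
    finally:
        plt.close(fig)

# Channels-first: the last axis has length 64 (not 3 or 4), so imshow rejects the array.
error_chw = try_imshow(grid_chw)

# Channels-last (H, W, C): the NumPy counterpart of .permute(1, 2, 0).
grid_hwc = np.transpose(grid_chw, (1, 2, 0))
error_hwc = try_imshow(grid_hwc)

print(error_chw)  # the TypeError message (wording depends on the matplotlib version)
print(error_hwc)  # None
```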