vutils.save_image
是 PyTorch 的 torchvision.utils
模块中的一个函数,用于保存张量(tensor)或者一批张量(batch of tensors)为图像文件。该函数对于可视化和保存模型生成的图像特别有用。
下面是该函数的一些详细信息和常见的用法:
函数签名
torchvision.utils.save_image(tensor, fp, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0, format=None)
参数
tensor
(Tensor or list of Tensor): 要保存的图像数据。如果是4D张量,假设格式为n x c x h x w
,其中n
是图像的数量,c
是通道数,h
是高度,w
是宽度。fp
(str or pathlib.Path object or file object): 一个文件名或一个打开的文件对象,用于保存图像。nrow
(int, optional): 每行有多少张图像,仅当保存多张图像为网格布局时有用。padding
(int, optional): 图像间的填充大小。normalize
(bool, optional): 是否将图像数据标准化。如果为True
,根据range
参数中给出的最小和最大值,将张量的每个通道标准化到 ([0, 1])。range
(tuple, optional):(min, max)
形式的元组,指定标准化的范围。min
和max
是张量应该标准化的范围,仅当normalize=True
时有效。scale_each
(bool, optional): 是否独立地对每张图像进行标准化(而不是基于所有图像的最小值和最大值)。pad_value
(float, optional): 填充的像素值。format
(str, optional): 图像文件的格式(如 ‘jpeg’、‘png’)。如果未指定,将从文件名扩展名推断出格式。
常见用法
import torchvision.utils as vutils
import torch
# 假设你有一个张量表示一批图像
images = torch.randn(64, 3, 32, 32) # 64张3通道的32x32图像
# 保存单个图像
vutils.save_image(images[0], 'single_image.png')
# 保存一批图像为网格布局
vutils.save_image(images, 'image_grid.png', nrow=8, normalize=True)
# 保存图像时标准化并指定填充和图像格式
vutils.save_image(images, 'image_grid.jpg', nrow=8, padding=4, normalize=True, range=(-1, 1), format='JPEG')
在实际应用中,你可能会在训练过程中使用 vutils.save_image
来保存生成的样本,以便监控模型的进展。通过 normalize
和 range
参数,你可以确保即使图像的原始像素值不在标准的图像表示范围内,保存的图像也能正确显示。