pytorch小记(二十):深入解析 PyTorch 的 `torch.randn_like`:原理、参数与实战示例
在深度学习模型中,我们经常需要在已有张量的基础上生成与之「同形状」「同设备」「同或不同数据类型」的随机噪声,用于参数扰动、数据增强、扩散模型等场景。PyTorch 为我们提供了一个高效便捷的工具——torch.randn_like
,它能一步完成上述需求。本文将从函数定义、参数详解、典型应用场景,到进阶用法,全面剖析 torch.randn_like
,并通过丰富示例帮助你快速上手。
一、函数签名与参数详解
torch.randn_like(
input: Tensor,
*,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
device: Optional[torch.device] = None,
requires_grad: bool = False,
memory_format: Optional[torch.memory_format] = None
) → Tensor
-
input
(必选)
源张量,randn_like
会读取它的.shape
、.dtype
、.device
、.layout
、以及memory_format
(如果未显式指定覆盖项)。 -
dtype
(可选)
生成张量的数据类型,如torch.float32
、torch.int64
等。若不指定,则继承input.dtype
。 -
device
(可选)
指定在 CPU 还是 GPU 上创建新张量,如"cpu"
、"cuda:0"
。若不指定,则继承input.device
。 -
requires_grad
(可选)
是否对新张量开启梯度追踪,默认为False
。 -
其他
layout
:张量内存布局,一般使用默认;memory_format
:指定内存格式,如torch.contiguous_format
。
二、torch.randn_like
vs torch.randn
方法 | 参数 | 优点 |
---|---|---|
torch.randn(size) | 必须手动传入 size 、可选传入 dtype 、device 等 | 简单直观,适合只关心形状的场景 |
torch.randn_like(input) | 自动继承 input.shape 、dtype 、device 、layout 等属性 | 减少样板代码,保证输出张量与输入环境一致 |
三、基础示例
import torch
# 1. 构造一个形状为 (2, 3) 的零张量
x = torch.zeros(2, 3)
print("x:", x.shape, x.dtype, x.device)
# x: torch.Size([2, 3]) torch.float32 cpu
# 2. 生成与 x 同形状同属性的标准正态随机张量
noise = torch.randn_like(x)
print("noise:", noise)
# 示例输出:
# tensor([[-0.1245, 0.5487, -0.3221],
# [ 0.8723, -1.0054, 0.0392]])
- 新张量
noise
与x
的形状、数据类型、设备保持一致。
四、进阶用法与参数覆盖
4.1 覆盖数据类型(dtype)
# 强制生成 float64 类型
noise_fp64 = torch.randn_like(x, dtype=torch.float64)
print(noise_fp64.dtype) # torch.float64
4.2 覆盖设备(device)
if torch.cuda.is_available():
noise_gpu = torch.randn_like(x, device=torch.device('cuda:0'))
print(noise_gpu.device) # cuda:0
4.3 开启梯度追踪(requires_grad)
noise_grad = torch.randn_like(x, requires_grad=True)
print(noise_grad.requires_grad) # True
4.4 覆盖内存格式(memory_format)
noise_contig = torch.randn_like(x, memory_format=torch.contiguous_format)
# 通常无需显式指定,除非对内存布局有特殊需求
五、典型应用场景
1. 给模型参数添加噪声
在对抗训练、参数平滑或元学习中,需要对权重做微小扰动:
import torch.nn as nn
class NoisyLinear(nn.Linear):
def forward(self, input):
# 为权重张量添加微小高斯噪声
weight_noise = torch.randn_like(self.weight) * 0.01
return nn.functional.linear(input, self.weight + weight_noise, self.bias)
layer = NoisyLinear(128, 64)
x = torch.randn(32, 128)
out = layer(x) # 前向过程中,自动生成同形状噪声
2. 数据增强:图像高斯噪声
对图像 Batch 注入随机噪声,提升模型鲁棒性:
# 假设 images 形状为 [B, C, H, W]
images = torch.randn(16, 3, 224, 224) # 示例输入
noise_std = 0.1
noisy_images = images + torch.randn_like(images) * noise_std
# 这样可以保证噪声形状 / dtype / device 与 images 完全一致
3. 扩散模型(DDPM)中的噪声采样
在扩散模型中,需要不断向数据添加标准正态噪声,且噪声张量形状与数据完全对齐:
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
# 根据时间步 t 计算噪声比例等后续操作...
return x_start * alpha_t[t] + noise * beta_t[t]
六、多种等价写法
tensor.long()
、tensor.to(torch.int64)
tensor.type(torch.float32)
等方法,均可对已有张量做类型转换,与randn_like
结合时常用于进一步处理。
七、小结
- 功能:
torch.randn_like
快速生成与指定张量同形状、同设备的标准正态分布随机张量。 - 参数覆盖:可选
dtype
、device
、requires_grad
、memory_format
等,灵活适配各种需求。 - 典型场景:参数扰动、数据增强、扩散模型、随机索引等。
- 最佳实践:在不关心形状等属性细节时,用
randn_like
省去 boilerplate;在需要覆盖属性时,通过关键字参数一次性完成。