pytorch小记(六):pytorch中的clone和detach操作:克隆/复制数据 vs 共享相同数据但 与计算图断开联系
以下代码片段:
self.x = x.clone().detach() # 或 torch.tensor(x).float()
用于处理和复制张量 x
,并根据需要使其与原始计算图断开联系或改变其数据类型。下面是逐部分详细解释。
1. x.clone()
- 作用:对张量
x
进行深拷贝,生成一个新的张量。- 新的张量和原始张量具有相同的数据,但存储在不同的内存空间。
- 修改
clone()
的返回值不会影响原始张量。
示例:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.clone()
y[0] = 99.0
print(x) # tensor([1., 2., 3.], grad_fn=<CloneBackward>)
print(y) # tensor([99., 2., 3.])
2. x.detach()
- 作用:返回一个与
x
共享相同数据但 与计算图断开联系 的张量。- 通常用于阻止梯度计算。
- 在神经网络中,如果你不希望某些操作影响反向传播时,会用到
detach()
。