一、x.clone()：返回一个与 x 数值相同、但存储独立的新张量（clone 操作本身可求导，梯度会回传到 x）。
错误示例（这样测不出 clone 的作用）：`x = x*12` 是非原地运算，会新建一个张量并把名字 x 重新绑定过去，原来的张量并未被修改，所以即使不 clone，y 的值也不会变：
import torch

# Random 5-element input; printed so it can be compared with y below.
x = torch.randn(5)
print(x)

# torch.clone(x) is equivalent to x.clone(): y gets its own copy of x's data.
y = torch.clone(x)

# Out-of-place multiply: builds a new tensor and rebinds the name x to it.
# The tensor y was cloned from is not modified, so y still shows the old values.
x = x.mul(12)
print(y)
tensor([-0.0455, 0.9827, -0.7915, 0.8110, 1.1524])
tensor([-0.0455, 0.9827, -0.7915, 0.8110, 1.1524])
正确示例：把运算结果存到新变量 z，x（以及它的副本 y）都保持原值；只有在原地修改（如 `x.mul_(12)`、`x *= 12`）时，才需要先 clone 来保留旧值：
import torch

x = torch.randn(5)
print(x)

# Independent copy of x (same values, separate storage); not used later here.
y = torch.clone(x)

# Keep the scaled result under a new name so x itself stays unchanged.
z = torch.mul(x, 12)
print(z)
tensor([ 1.7743, -2.0879, 0.4297, -0.3849, -0.6147])
tensor([ 21.2920, -25.0553, 5.1562, -4.6190, -7.3768])
import torch
from torch import nn


class A(nn.Module):
    """Demo module showing that rebinding ``x`` leaves an alias ``y`` intact.

    ``forward`` keeps a second name ``y`` for its input, then rebinds ``x`` to
    ``self.m(x) * 12``. Multiplying by 12 this way is an out-of-place op that
    allocates a new tensor, so the original tensor (still referenced by ``y``)
    is not modified and no ``.clone()`` is needed. A clone would only be
    required if ``x`` were modified in place (e.g. ``x.mul_(12)`` or
    ``x *= 12``), because ``y`` refers to the very same tensor object.
    """

    def __init__(self):
        super().__init__()
        # nn.Identity returns its input unchanged. nn.Linear(12, 12) is a
        # drop-in alternative for the 12-element input used in the demo call.
        self.m = nn.Identity()

    def forward(self, x):
        """Print the input, the scaled result and the alias; return the result.

        Args:
            x: input tensor, passed through ``self.m``.

        Returns:
            The tensor ``self.m(x) * 12``.
        """
        print(f'old x:{x}')
        # Plain alias, no clone: x is only rebound below, never mutated in place.
        y = x
        x = self.m(x) * 12
        print(f'new x:{x}')
        print(f'res y:{y}')
        return x


A()(torch.randn(12))
old x:tensor([-0.0204, 0.5557, 2.1049, -1.2744, -1.0640, 1.2480, -2.2674, -0.6850,
0.0507, -0.9022, -0.2756, -0.1572])
new x:tensor([ -0.2450, 6.6683, 25.2586, -15.2927, -12.7677, 14.9765, -27.2085,
-8.2195, 0.6080, -10.8261, -3.3074, -1.8865])
res y:tensor([-0.0204, 0.5557, 2.1049, -1.2744, -1.0640, 1.2480, -2.2674, -0.6850,
0.0507, -0.9022, -0.2756, -0.1572])