解决前（当 self.fc2.weight 与 self.zeros 不在同一设备上时，下面这行会报设备不一致的错误）：
# Before the fix: multiplying two tensors that live on different devices raises a
# RuntimeError, so this line fails whenever self.fc2.weight and self.zeros are not
# on the same device. (Presumably fc2 was moved to the GPU via model.to()/.cuda()
# while self.zeros, a plain tensor attribute, stayed on the CPU — confirm against
# the model definition.)
self.fc2.weight = nn.Parameter(self.fc2.weight * self.zeros)
解决后
# After the fix: bring self.zeros onto whatever device fc2's weight already lives on.
# Using the weight's own device (instead of hard-coding torch.device('cuda:0'))
# keeps this working on CPU-only machines and on any GPU index.
device = self.fc2.weight.device
# No trailing .to(device) on the new Parameter: it would be a no-op here, and if it
# ever did move the tensor it would return a plain Tensor, which nn.Module refuses
# to assign to a registered parameter attribute.
# NOTE(review): this replaces fc2.weight with a new Parameter object, so an
# optimizer built before this line still holds the old one — confirm this runs
# before the optimizer is constructed. Registering self.zeros via
# self.register_buffer('zeros', ...) would let model.to(device) move it
# automatically and make this explicit .to() unnecessary.
self.fc2.weight = nn.Parameter(self.fc2.weight * self.zeros.to(device))
解决前（当 self.fc2.weight 与 self.zeros 不在同一设备上时，下面这行会报设备不一致的错误）：
# Before the fix: multiplying two tensors that live on different devices raises a
# RuntimeError, so this line fails whenever self.fc2.weight and self.zeros are not
# on the same device. (Presumably fc2 was moved to the GPU via model.to()/.cuda()
# while self.zeros, a plain tensor attribute, stayed on the CPU — confirm against
# the model definition.)
self.fc2.weight = nn.Parameter(self.fc2.weight * self.zeros)
解决后
# After the fix: bring self.zeros onto whatever device fc2's weight already lives on.
# Using the weight's own device (instead of hard-coding torch.device('cuda:0'))
# keeps this working on CPU-only machines and on any GPU index.
device = self.fc2.weight.device
# No trailing .to(device) on the new Parameter: it would be a no-op here, and if it
# ever did move the tensor it would return a plain Tensor, which nn.Module refuses
# to assign to a registered parameter attribute.
# NOTE(review): this replaces fc2.weight with a new Parameter object, so an
# optimizer built before this line still holds the old one — confirm this runs
# before the optimizer is constructed. Registering self.zeros via
# self.register_buffer('zeros', ...) would let model.to(device) move it
# automatically and make this explicit .to() unnecessary.
self.fc2.weight = nn.Parameter(self.fc2.weight * self.zeros.to(device))