这两个地方改成:
# Append the current loss value to the history list, then print a progress line.
# The format string must be delimited by plain ASCII quotes; typographic
# quotes (‘ ’) are not valid string delimiters in Python.
hist_loss.append(loss.data)
print('[epoch %d/%d] [iter %d/%d] loss %f acc %f %f/batch' % (
    epoch, num_epoch,
    idx_train, len(train_file) // batch_size, loss.data,
    correct / batch_size, dtime))
这些函数前面加上 @staticmethod:
class LogFunction_v2(Function):
@staticmethod
def forward(self, input):
Us = torch.zeros_like(input)
Ss = torch.zeros((input.shape[0], input.shape[1])).double()
logSs = torch.zeros_like(input)
invSs = torch.zeros_like(input)
for i in range(input.shape[0]):
U, S, V = torch.svd(input[i, :