报错的代码如下:
def train(n_epochs, model, train_x, train_y, test_x, max_cycle_t, y_test):
    """Run ``n_epochs`` full-batch training steps of ``model``.

    Every epoch performs one forward pass over ``train_x``, evaluates the
    module-level ``criterion`` against ``train_y``, back-propagates the
    loss and applies a single step of the module-level ``optimizer``.

    Args:
        n_epochs: Number of epochs to run (epochs are counted from 1).
        model: Callable network exposing ``train()``.
        train_x: Training inputs, passed to ``model`` in one batch.
        train_y: Training labels, passed to ``criterion`` as the target.
        test_x: Not used in the visible part of the function.
        max_cycle_t: Not used in the visible part of the function.
        y_test: Not used in the visible part of the function.

    NOTE(review): ``rmse_history`` is created but never filled and the
    test arguments are unused here -- presumably this excerpt is only the
    beginning of a longer function; confirm against the full source.
    """
    rmse_history = []
    for epoch in range(1, n_epochs + 1):
        model.train()
        x_train, y_train = Variable(train_x), Variable(train_y)
        optimizer.zero_grad()
        output_train = model(x_train)
        # This is the call where the wrong label dtype surfaced: the labels
        # must already have the dtype ``criterion`` expects (the fix was to
        # convert ``train_y`` with ``.long()`` when the dataset is built).
        loss_train = criterion(output_train, y_train)
        loss_train.backward()
        optimizer.step()
原因应该是 y_train 标签的数据类型有误。找到 dataset_prepare 文件,在把 train_y 转换为张量的语句后面加上 .long():
# Cast the NumPy labels to an integer dtype, then force the resulting tensor
# to int64 (Long) so the label dtype no longer depends on the platform's
# default integer width.
train_y = torch.from_numpy(train_y.astype(int)).long()
问题解决。