训练数据最好使用DataSet和DataLoader,可以自动一个batch一个batch的取出数据,同时可以使用shuffle来打乱源数据顺序,防止因为顺序对模型造成一些不好的影响。
from torch.utils.data import DataLoader
HR_ds = TensorDataset(X, Y)
HR_dl = DataLoader(HR_ds, batch_size=batch)
model, opt = get_model()
for epoch in range(epochs):
for x, y in HR_dl:
y_pred = model(x)
loss = loss_fn(y_pred, y)
opt.zero_grad()
loss.backward()
opt.step()
print('epoch:', epoch, ' ', 'loss:', loss_fn(model(X), Y))