- 假设已经有一组 NumPy 数组；此处使用的框架为 PyTorch。
# Minimal PyTorch training-loop sketch over a NumPy array.
# NOTE(review): `net`, `optimizer`, `my_criterion` and `path` are not defined
# in this snippet and must exist beforehand; a CUDA device is also required.
x = np.ones((10000, 3))
# from_numpy shares memory with the float64 array; .float() then yields a
# float32 tensor of shape (10000, 3).
x = torch.from_numpy(x).float()
dataset = TensorDataset(x)
dataloader = DataLoader(
    dataset, batch_size=50, shuffle=True, num_workers=0, drop_last=True
)
device = torch.device("cuda")
epochs = 500
for epoch in range(epochs):
    # TensorDataset yields one tensor per wrapped tensor, so each batch is a
    # 1-tuple; unpack it directly rather than round-tripping through NumPy.
    for batch_idx, (data,) in enumerate(dataloader):
        net.zero_grad()
        data = data.to(device)  # shape (batch_size, 3) = (50, 3)
        loss = my_criterion(data)
        loss.backward()
        optimizer.step()
        # With 10000 / 50 = 200 batches per epoch this fires at batch 0 and 100.
        # The filename depends only on the epoch, so the later save overwrites
        # the earlier one.
        if batch_idx % 100 == 0:
            torch.save(net.state_dict(), f"{path}_epoch{epoch}myy.pt")