对数据进行封装:TensorDataset
torch.utils.data.Dataset 表示一个数据集的抽象类,所有的其它数据集都要以它为父类进行数据封装。
TensorDataset 继承自 Dataset,重载了__init__,getitem,len。
class TensorDataset(Dataset):
def __init__(self,data_tensor,target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
def __len__(self):
return self.data_tensor.size(0)
- data_tensor 是需要被封装的数据样本
- target_tensor 是需要被封装的数据标签
加载数据:DataLoader
torch.utils.data.DataLoader 结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时候该类可以将数据进行切分,每次抛出一组数据,直至把所有的数据都抛出。
实例如下:
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
)
print(loader)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("epoch:{}, step:{}, batch_x:{}, batch_y:{}".format(epoch,step,batch_x,batch_y))
if __name__ == '__main__':
show_batch()
保存模型
直接保存整个模型并读取:
# 创建模型的示例对象:model
model = net()
# 保存模型
torch.save(model, 'model_name.pth')
# 读取模型
model = torch.load('model_name.pth')
只保存模型中的参数:
# 定义函数,保存最新和最佳模型
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')
# 调用时:
save_checkpoint({
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best, filename=os.path.join(args.save_dir, 'model.th'))
保留验证集上最好的模型
验证集的作用是在训练的过程中监测是否训练过度(过拟合)。
一般可以默认验证集的损失函数由下降转为上升(即最小值)处,模型的泛化能力最好。
min_loss_val = 10 # 任取一个大数
best_model = None
min_epoch = 100 # 训练至少需要的轮数
for epoch in range(args.epochs):
loss_val, loss_acc = train(epoch)
if epoch > min_epoch and loss_val <= min_loss_val:
min_loss_val = loss_val
best_model = copy.deepcopy(model)
model = best_model