【PyTorch】数据封装和模型保存

对数据进行封装:TensorDataset

torch.utils.data.Dataset 表示一个数据集的抽象类,所有的其它数据集都要以它为父类进行数据封装。
TensorDataset 继承自 Dataset,重载了__init__,getitem,len

class TensorDataset(Dataset):
    def __init__(self,data_tensor,target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]
    def __len__(self):
        return self.data_tensor.size(0)
  • data_tensor 是需要被封装的数据样本
  • target_tensor 是需要被封装的数据标签

加载数据:DataLoader

torch.utils.data.DataLoader 结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时候该类可以将数据进行切分,每次抛出一组数据,直至把所有的数据都抛出。
实例如下:

import torch
import torch.utils.data as Data

BATCH_SIZE = 5
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
print(loader)

在这里插入图片描述

def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print("epoch:{}, step:{}, batch_x:{}, batch_y:{}".format(epoch,step,batch_x,batch_y))

if __name__ == '__main__':
    show_batch()

在这里插入图片描述

保存模型

直接保存整个模型并读取:

# 创建模型的示例对象:model
model = net()

# 保存模型
torch.save(model, 'model_name.pth')
# 读取模型
model = torch.load('model_name.pth')

只保存模型中的参数:

# 定义函数,保存最新和最佳模型
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

# 调用时:
save_checkpoint({
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
}, is_best, filename=os.path.join(args.save_dir, 'model.th'))

保留验证集上最好的模型

验证集的作用是在训练的过程中监测是否训练过度(过拟合)。
一般可以默认验证集的损失函数由下降转为上升(即最小值)处,模型的泛化能力最好。

min_loss_val = 10 # 任取一个大数
best_model = None
min_epoch = 100 # 训练至少需要的轮数
for epoch in range(args.epochs):
    loss_val, loss_acc = train(epoch)
    if epoch > min_epoch and loss_val <= min_loss_val:
        min_loss_val = loss_val
        best_model = copy.deepcopy(model)
model = best_model
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值