(2)Pytorch保存模型权重参数

保存模型权重参数

我们可以选择只保存模型权重参数,或者保存模型结构+权重参数,通常采用前者。此处介绍只保存模型权重的方法

1、只保存模型权重参数

# dir = 'xxxx/resnet18.pth'
import torch
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

# 保存
torch.save(resnet18.state_dict(), 'xxxx/resnet18.pth')
# 调用
resnet18 = models.resnet18() 
resnet18.load_state_dict(torch.load('xxxx/resnet18.pth'))

2、保存模型权重、优化器权重、epoch信息

dir = 'mymodel.pth'
state = {'net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, dir) # 权重参数包括了模型权重、优化器权重、epoch
checkpoint = torch.load(dir) # checkpoint 把之前save的state加载进来
model.load_state_dict(checkpoint['net'])

optimizer = torch.optim.Adam(model.parameters()) # 定义optimizer
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

补充 optimizer 优化器的定义资料
PyTorch优化算法:optimizer=torch.optim.Adam参数介绍

train_dataset = Real('./data/SIDD_train/', 320, args.ps) + Syn('./data/Syn_train/', 100, args.ps)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)

dataloader 代码撰写

# 使用的代码
dataset_train = MyTrainDataSet(input_path, label_path)


from torch.utils.data import  Dataset
class MyTrainDataSet(Dataset):
    def __init__(self, input_path, label_path):  ## 创建实例时调用
        super(MyTrainDataSet, self).__init__()
        self.input_path = input_path
        self.input_images = os.listdir(input_path)

        self.label_path = label_path
        self.target_images = os.listdir(label_path)

        self.transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),]) # 就用256的大小进行训练!

    def __len__(self):
        return len(self.input_images)
    
    def __getitem__(self, index):  ## 
        input_image_path = os.path.join(self.input_path, self.input_images[index])
        input_image = Image.open(input_image_path).convert('RGB')

        label_image_path = os.path.join(self.label_path, self.target_images[index])
        label_image = Image.open(label_image_path).convert('RGB')

        input = self.transforms(input_image)
        label = self.transforms(label_image)

        return input, label

理论补充
1、

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值