保存dataloader状态 恢复中断训练

当使用PyTorch进行大型数据集训练且希望从中断点恢复时,标准数据加载器会从头开始。为解决此问题,可以自定义Sampler类,如`MySampler`,在初始化时根据上次保存的步数跳过相应数据。通过设置`shuffle=False`并传入自定义sampler到DataLoader,能够继续训练而不会重新加载所有数据。另一种简单方法是手动空跑至恢复的步数,但效率较低。
摘要由CSDN通过智能技术生成

对于pytorch恢复一个epoch中的中断的训练时,通常dataloader都会从头加载,对于大型数据集不友好,loss又重新下降了

这时候可以自定义sampler

import random
from torch.utils.data.dataloader import Sampler


random.seed(224)  # use a fixed number


class MySampler(Sampler):
    def __init__(self, data, i=0):
        random.shuffle(data)#自定义shuffle
        self.seq = list(range(len(data)))[i * batch_size:]

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

调用dataloader时传入自定义sampler,指定恢复的step

train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,                                                         
                               batch_size=batch_size, 
                               sampler=train_sampler,
                               shuffle=False)  # don't forget to set DataLoader's shuffle to False

就可以啦!ref
也可以用笨方法,空跑到指定的step:

for batch in train_loader:
  if restart_step<global_step:
     restart_step+=1
     pbar.update(1)
     continue
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值