When resuming PyTorch training that was interrupted in the middle of an epoch, the DataLoader normally starts again from the beginning of the dataset. For a large dataset this is wasteful: the batches you already trained on are replayed, and the loss just re-descends over them.
One fix is a custom sampler that skips the batches consumed before the interruption.
import random

from torch.utils.data import Sampler

random.seed(224)  # use a fixed seed so the shuffle order is identical after a restart

class MySampler(Sampler):
    def __init__(self, data, batch_size, i=0):
        # shuffle indices rather than the dataset object itself; with the
        # fixed seed above, the shuffled order is reproducible on restart
        self.seq = list(range(len(data)))
        random.shuffle(self.seq)
        # drop the indices belonging to the i batches already trained on
        self.seq = self.seq[i * batch_size:]

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)
When building the DataLoader, pass in the custom sampler with the step to resume from:
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, batch_size, last_i)
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               sampler=train_sampler,
                               shuffle=False)  # don't forget shuffle=False: a custom sampler and shuffle are mutually exclusive
And that's it! (ref)
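The call above assumes `last_i` (how many batches of the current epoch were already consumed) was recorded at checkpoint time. Here is a minimal sketch of that bookkeeping, assuming hypothetical `model`, `optimizer`, and `checkpoint.pt` names that are not from the original:

import torch

# at checkpoint time: save the batch counter alongside the usual state
# (model / optimizer / checkpoint.pt are illustrative names)
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'last_i': i,  # batches of the current epoch already trained on
}, 'checkpoint.pt')

# at resume time: restore state and feed last_i into MySampler
ckpt = torch.load('checkpoint.pt')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
last_i = ckpt['last_i']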
You can also use the brute-force method: dry-run through the loader until the target step is reached:
for batch in train_loader:
    if restart_step < global_step:  # global_step: step index saved in the checkpoint
        restart_step += 1
        pbar.update(1)  # keep the progress bar in sync while skipping
        continue
    # ... normal training step from here on
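Two caveats on this variant. First, it still loads and collates every skipped batch, so the fast-forward itself can take a while on a large dataset. Second, it only lands on the right batches if the loader's shuffle order is reproducible across restarts, for example by passing a seeded generator to the DataLoader (a sketch; the seed value is illustrative):

import torch
from torch.utils.data import DataLoader

g = torch.Generator()
g.manual_seed(224)  # same seed on every run -> same shuffle order to skip through
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          generator=g)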