单GPU保存与加载:
模型保存: ####方法一 state = {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()} torch.save(state, 'model_path') ####方法二 torch.save(self.model.state_dict(), 'model_path') # 当前目录 ####################################################################################### 模型加载: ####方法一 self.model.load_state_dict(torch.load("model_path")['model']) ####方法二 self.model.load_state_dict(torch.load('model_path'))
DDP模式下模型保存与加载:
模型保存: 无module形式 ####方法一 state = {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()} torch.save(state, 'model_path') ####方法二 torch.save(self.model.state_dict(), 'model_path') module形式---建议模式 ####方法一 state = {'epoch': epoch, 'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()} torch.save(state, 'model_path') ####方法二 torch.save(self.model.module.state_dict(), 'model_path') ####################################################################################### 模型加载: 未用含"moduel"方式保存, 导致缺失关键“key”:Missing key(s) in state_dict ############### 方法 1: add model = torch.nn.DataParallel(model) # 加上module model.load_state_dict(torch.load("model_path")) ############### 方法 2: remove model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load("model_path").items()}) ############### 方法 3: remove from collections import OrderedDict state_dict = torch.load("model_path") new_state_dict = OrderedDict() # create new OrderedDict that does not contain `module.` for k, v in state_dict.items(): name = k.replace('module.', '') new_state_dict[name] = v model.load_state_dict(new_state_dict) 含"moduel"方式保存 ####方法一 self.model.load_state_dict(torch.load("model_path")['model']) ####方法二 self.model.load_state_dict(torch.load('model_path))
DDP训练流程:
初始化
#初始化使用nccl后端(这个),当然还有别的后端,可以查看官方文档,介绍的比较清晰 torch.distributed.init_process_group(backend="nccl")
使用DistributedSampler
DDP并不会自动shard数据如果自己写数据流,得根据
torch.distributed.get_rank()
去shard数据,获取自己应用的一份如果用Dataset API,则需要在定义Dataloader的时候用DistributedSampler
去shard:分布式训练
model=torch.nn.parallel.DistributedDataParallel(model)
- 参考文章https://zhuanlan.zhihu.com/p/95700549
https://zhuanlan.zhihu.com/p/95700549
https://zhuanlan.zhihu.com/p/145427849
https://zhuanlan.zhihu.com/p/145427849
DDP训练问题:
1.自定义的模型结构,继承自nn.Module,除__init__(), forward()等重写方法外,在模型结构类内部自定义的一些函数如loss()等,在DDP训练时调用方式需注意
model = YoulModel()
model.loss() ### 找不到属性loss#在使用net = torch.nn.DDP(net)之后,原来的net会被封装为新的net的module属性里
model.module.loss()