pytorch 普通模型以及 BERT加载与保存

torch.save doc

在这里插入图片描述
主要用的就前两个参数

  • obj:要保存的python 对象

  • f:open出来的io文件,或者是只是保存文件路径,文件名的str(后者应该就是把这个str 以"w"方式open出来了)

注意obj这个对象必须要能够serialization(如果是你自己自定义的obj,要实现serialization).一般而言,想要自己定义的obf能够序列化,可以实现to_dict,to_json method:

class Serializable(object):

    def __init__(self):
    	pass

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

普通torch.nn.Module

1. save

个人一般喜欢用下面的方式保存torch的模型:

	@overrides
    def _save(self, save_path):
        encoder_path = os.path.join(save_path, "encoder")
        dense_path = os.path.join(save_path, "project_layers")
        if not os.path.exists(encoder_path):
            os.mkdir(encoder_path)
        if not os.path.exists(dense_path):
            os.mkdir(dense_path)
        # save LSTM
        torch.save({'state_dict': self._encoder.state_dict()}, os.path.join(encoder_path, "Bi-LSTM.pth.tar"))
        # save project_layers
        torch.save({'state_dict': self._project_layer.state_dict()}, os.path.join(dense_path, "dense.pth.tar"))

也就是torch.save的第一个参数可以直接给一个dict,包含一个model的state_dict字段,后面就是保存模型的文件名,我喜欢保存成tar包

2. load

如果按照上面的方法,用torch.save把模型state和param存成dict保存的话,load就是反过来

	@overrides
	def _load(self, save_path) -> bool:
	   encoder_path = os.path.join(save_path, "encoder")
	   dense_path = os.path.join(save_path, "project_layers")
	   if not os.path.exists(encoder_path):
	       return False
	   if not os.path.exists(dense_path):
	       return False
	   try:
	       # load LSTM
	       lstm_checkpoint = torch.load(os.path.join(encoder_path, "Bi-LSTM.pth.tar"))
	       self._encoder.load_state_dict(lstm_checkpoint['state_dict'])
	       # load project_layers
	       dense_checkpoint = torch.load(os.path.join(dense_path, "dense.pth.tar"))
	       self._project_layer.load_state_dict(dense_checkpoint['state_dict'])
	   except FileNotFoundError:
	       print("model checkpoints file missing: %s" % save_path)
	       return False
	   return True

先用torch.load把dict load过来,然后调用torch.nn.Module的类方法load_state_dict,把dict里面的state_dict传入(这里假定你的模型结构已经构造好了,或者模型还在ram里面,就可以直接load回来)
如果load成功,也就是你保存的state_dict和模型的结构是吻合的,那么应该会有类似于All key correct之类的字样

transformer BERT

具体可以参见huggingface的doc

1. save

	@overrides
    def _save(self, save_path):
        encoder_path = os.path.join(save_path, "encoder")
        dense_path = os.path.join(save_path, "project_layers")
        if not os.path.exists(encoder_path):
            os.mkdir(encoder_path)
        if not os.path.exists(dense_path):
            os.mkdir(dense_path)
        # save BERT weight,config and tokenizer
        model_to_save = (self._encoder.module if hasattr(self._encoder, "module") else self._encoder)
        model_to_save.save_pretrained(encoder_path)
        self._tokenizer.save_pretrained(encoder_path)
        # save project_layers
        torch.save({'state_dict': self._project_layer.state_dict()}, os.path.join(dense_path, "dense.pth.tar"))

这里model_to_save的目的是为了防止并行训练的时候冲突,一般而言不用写,放这里也不会有错。
如果你的模型是transformer的PreTrainedModel (同时也会是torch.nn.Module子类),那么就可以把你的模型直接用save_pretrained保存到指定路径。
熟悉huggingface框架的人都清楚pretrained bert一般需要三样东西:config,tokenizer,model.bin,
model.save_pretrained其实就是保存了模型参数model.bin以及config json文件,同时再把它配套的tokenizer也保存到相同的路径下,就会把vocab之类的也保存。

2. load

    @overrides
    def _load(self, save_path) -> bool:
        encoder_path = os.path.join(save_path, "encoder")
        dense_path = os.path.join(save_path, "project_layers")
        if not os.path.exists(encoder_path):
            return False
        if not os.path.exists(dense_path):
            return False
        try:
            # load BERT weight,config and tokenizer
            self._encoder = AutoModel.from_pretrained(encoder_path)
            self._tokenizer = AutoTokenizer.from_pretrained(encoder_path)
            self._encoder.cuda()
            # load project_layers
            dense_checkpoint = torch.load(os.path.join(dense_path, "dense.pth.tar"))
            self._project_layer.load_state_dict(dense_checkpoint['state_dict'])
        except FileNotFoundError:
            print("model checkpoints file missing: %s" % save_path)
            return False
        return True

可以直接新建AutoClass,然后from_pretrained就可以了,不管你保存的是什么模型,只要路径下所有信息都是完整的,那么Autoclass就会确定你的model种类(BERT base/large Roberta base/large).
所以先把model load过来,这个过程中先会把config和model.bin,也就是模型的结构和参数全部重新设置好,然后tokenizer同样

当然也可以先Autoconfig建立好,然后把config当作参数传入model.from_pretrained,但是没必要,因为之所以transformer在使用的时候一般一开始需要独立地新建一个Autoconfig是因为你可以在config里面自己额外再定义一些东西,但是这里是自己训练的模型所以不需要。换句话说,如果你只是想要用最bare的bert,那么直接AutoModel.from_pretrained就可以了;如果你想要用BertModelForSequenceClassification之类的model,那么你需要新建一个Autoconfig在里面指定num_class(不然怎么知道顶层dense的outdim),再from_pretrained(config=)将config传入

  • 7
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值