transformers自定义模型的保存和加载

本文介绍如何使用PyTorch保存和加载自定义的PLBART模型参数,并提供了一个独立的模型文件示例,便于管理和复用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

step1 保存 (my_plbart.py)

#如果一开始用了并行训练最好加上这句
model_to_save = model.module if hasattr(model, 'module') else model
#这样保存的是模型参数,记得格式是.pt
torch.save(model_to_save.state_dict(),output_model_dir+"model-2.pt")

step2 加载 (use_plbart.py)

#因为是自定义模型呀
model = Model()
#拿到保存的参数
model_static_dict = torch.load(output_model_dir+"model-2.pt")
#把参数加载到模型中
model.load_state_dict(model_static_dict)

注意:

两个文件中的 output_model_dir 路径和Model类应该是一致的。

话外:

如果你的模型不是自定义的,而是直接用的transformers中from_pretrained得到的,那么可以直接用save_pretrained进行保存。以上提供的是更一般化的方法,即torch对模型参数保存和加载的支持。

附上完整的模型文件 only_model.py

import torch
from transformers import PLBartConfig, PLBartModel, PLBartTokenizer

plbart_hf_path = "uclanlp/plbart-multi_task-java"
plbart_local_path = "your_path/plbart_files"
output_model_dir = 'your_path/PLBART_huggingface/finetuned_models/'


checkpoint = plbart_local_path

myTokenizer = PLBartTokenizer.from_pretrained(checkpoint)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.pretrained = PLBartModel.from_pretrained(checkpoint)
        # 定义一组值全为0的常量
        self.register_buffer(
            "final_logits_bias",
            torch.zeros(1, myTokenizer.vocab_size)
        )
        self.fc = torch.nn.Linear(768, myTokenizer.vocab_size, bias=False)
        # 加载预训练模型的参数
        parameters = PLBartConfig()
        # self.fc.load_state_dict(parameters.lm_head.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        logits = self.pretrained(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids
        )
        logits = logits.last_hidden_state
        logits = self.fc(logits)+self.final_logits_bias
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())
        return {"loss": loss, "logits": logits}

(only_model.py被其他两个py引用,单拎出来形成一个模型文件的好处是,如果直接用use_plbart.py引用my_plbart.py,还会引用进很多无关的代码,Maybe非常耗时甚至卡住)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CSU迦叶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值