半监督学习文本分类(六)

2021SC@SDUSC
文本分类半监督学习(六)
2021SC@SDUSC​​​​​​​

分析完对于数据的加载以后,来进行bert模型的加载与训练
首先我们来分析normal_bert.py文件,这个文件中是实现bert基线模型的代码
首先对其构造函数进行分析,我们可以看到,构造函数中,运用到了BertModel.from_pretrained(‘bert-base-uncased’,force_download=True),这个其实是通过网络进行下载bert-base-uncased基线模型。
在这里,其实我对于作者给出的代码进行了修改,因为原始代码在运行时候,会报出RuntimeError: unexpected EOF, expected 76257237 more bytes. The file might be corrupted.的错误,因此,我对于代码进行了修改:在函数后面添加了参数force_download=True,之后‘bert-base-uncased’模型才正常地下载下来。

class ClassificationBert(nn.Module):
    def __init__(self, num_labels=2):
        super(ClassificationBert, self).__init__()
        # Load pre-trained bert pip install pytorch-pretrained-bertmodel
        self.bert = BertModel.from_pretrained('bert-base-uncased',force_download=True)
        self.linear = nn.Sequential(nn.Linear(768, 128),
                                    nn.Tanh(),
                                    nn.Linear(128, num_labels))

下图为此模型下载的截图:
代码运行的图片

其次我们对其中具体方法foward()进行分析:
可以看到,首先在 all_hidden, pooler = self.bert(x)以及pooled_output = torch.mean(all_hidden, 1)中对于文本进行编码,其次利用linear函数进行预测

def forward(self, x, length=256):
    all_hidden, pooler = self.bert(x)
    pooled_output = torch.mean(all_hidden, 1)
    predict = self.linear(pooled_output)

    return predict

下面,我们进入normal_train.py文件分析训练bert模型的过程:
首先,我们在运行代码过程中,发现了一个关键性的错误,在输入数据集时,作者给出的path只是一个象征,而需要我们自己给出数据集的真实路径,额外的,作者在自述文件中也阐明了,利用代码对数据集进行清洗过慢,其已给出详细的完整数据集地址,仅需自行下载并修改地址即可,下图为我自己下载的数据集展示:
数据集
下方代码为我自己修改的代码:

parser.add_argument('--data-path', type=str, default='/Users/wuzehao/Desktop/科研/文本分类/MixText-master/data/yahoo_answers_csv/',
                    help='path to data folders')

再者,进入详细的normal_train.py文件的分析,可见第八篇博客。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值