2021SC@SDUSC
文本分类半监督学习(六)
2021SC@SDUSC
分析完对于数据的加载以后,来进行bert模型的加载与训练
首先我们来分析normal_bert.py文件,这个文件中是实现bert基线模型的代码
首先对其构造函数进行分析,我们可以看到,构造函数中,运用到了BertModel.from_pretrained(‘bert-base-uncased’,force_download=True),这个其实是通过网络进行下载bert-base-uncased基线模型。
在这里,其实我对于作者给出的代码进行了修改,因为原始代码在运行时候,会报出RuntimeError: unexpected EOF, expected 76257237 more bytes. The file might be corrupted.的错误,因此,我对于代码进行了修改:在函数后面添加了参数force_download=True,之后‘bert-base-uncased’模型才正常地下载下来。
class ClassificationBert(nn.Module):
def __init__(self, num_labels=2):
super(ClassificationBert, self).__init__()
# Load pre-trained bert pip install pytorch-pretrained-bertmodel
self.bert = BertModel.from_pretrained('bert-base-uncased',force_download=True)
self.linear = nn.Sequential(nn.Linear(768, 128),
nn.Tanh(),
nn.Linear(128, num_labels))
下图为此模型下载的截图:
其次我们对其中具体方法foward()进行分析:
可以看到,首先在 all_hidden, pooler = self.bert(x)以及pooled_output = torch.mean(all_hidden, 1)中对于文本进行编码,其次利用linear函数进行预测
def forward(self, x, length=256):
all_hidden, pooler = self.bert(x)
pooled_output = torch.mean(all_hidden, 1)
predict = self.linear(pooled_output)
return predict
下面,我们进入normal_train.py文件分析训练bert模型的过程:
首先,我们在运行代码过程中,发现了一个关键性的错误,在输入数据集时,作者给出的path只是一个象征,而需要我们自己给出数据集的真实路径,额外的,作者在自述文件中也阐明了,利用代码对数据集进行清洗过慢,其已给出详细的完整数据集地址,仅需自行下载并修改地址即可,下图为我自己下载的数据集展示:
下方代码为我自己修改的代码:
parser.add_argument('--data-path', type=str, default='/Users/wuzehao/Desktop/科研/文本分类/MixText-master/data/yahoo_answers_csv/',
help='path to data folders')
再者,进入详细的normal_train.py文件的分析,可见第八篇博客。