网络层处理文本
class SimpleLSTMBaseline(nn.Module):
def __init__(self, hidden_dim, emb_dim=300, num_linear=3):
super().__init__()
self.embedding=nn.Embedding(len(TEXT.vocab),emb_dim)
self.encoder=nn.LSTM(emb_dim,hidden_dim,num_layers=1)
self.linear_layers=[]
for _ in range(num_linear-1):
self.linear_layers.append(nn.Linear(hidden_dim,hidden_dim))
self.linear_layers=nn.ModuleList(self.linear_layers)
self.predictor=nn.Linear(hidden_dim,10)
def forward(self,seq):
hdn,_=self.encoder(self.embedding(seq))
feature=hdn[-1,:,:]
for layer in self.linear_layers:
feature=layer(feature)
preds=self.predictor(feature)
return preds
这段代码主要包括了embedding, encoder, linear三个操作,将从输入到最后输出,一步步拆解开看里面到底发生了什么。
embedding部分
其中self.embedding来自于__init__中得embedding定义。原始情况下一个样本的表示方式是通过词袋来进行表示,也就是每个词的数值表示范围决定于我们的词典大小(词典中也包含了unk,pad这些),理解为是一个一维空间中数字(序列号)来表示词,而通过embedding之后,我们将这个样本换在了另外一种空间下,使用另外一种方式,即很多个数字(float)来表示每个词(字),好处是可能能够获取到词(字)之间的内在关系。如下为一个小例子帮助理解:假设一个批次的样本,句子长度为101,样本个数为5,当前他们的词(字)是由词典序号进行表示的,接下来我们通过embedding,来将词换一种表示方式
x_sampel2=torch.arange(1,506).reshape((101,5))
vocab_size=len(TEXT.vocab)
emb_dim=300
print(vocab_size)embedding=nn.Embedding(vocab_size, emb_dim)
embed_res=embedding(x_sampel2)
encoder部分
encoder部分用了一个LSTM来对时序特征做"揉合",从第一个timestep到最后一个,我们取最后一个的值就好了。
LSTM层的第一个参数是embedding的维度,第二个是隐藏层维度,第三个是使用几层LSMT(留下的尾巴1)。返回值为每个timestep上的输出值,以及hidden_state,cell_state(留下的尾巴2),我们只需要hdn的最后一层输出就可以了。
hidden_dim=500
encoder=nn.LSTM(emb_dim,hidden_dim,num_layers=2)hdn,(hideen_state,cell_state)=encoder(embed_res)
这个批次一共包含了5个句子,每个句子的长度是101,每个词现在是由500维度的数字(float)来表示的。
linear部分
这里内部使用了nn.Linear,是对500维的数据上做进一步的权重调整,再接一个(dim,class_size)的Linear,就可以将数据由LSTM处理过(通过一个hidden表示的句子特征)以及Linear进一步蹂躏之后,将特征再压缩到Class_size的维度,便于在loss的阶段进行计算。(这里感觉如果是接svm分类器,不进行缩放到Class_size也可以,但是没有操作过。)
hdn,_=encoder(embed_res)
feature=hdn[-1,:,:]
class_size=10
linear=nn.Linear(hidden_dim,hidden_dim)
predict=nn.Linear(hidden_dim,class_size)
如下红框中的部分,是要将来很真实label进行loss计算的。真实label的维度是tensor([5]),如果是tensor([5,1])那就改改维度吧。
训练流程
注意点:
- model.train()指定当前模式为训练
- model.zero_grad()
- 计算preds
- 计算loss(preds, true)
- optim优化器,指定如何降低loss。
github地址:
https://github.com/mathCrazyy/text_classify/