# Understanding the Code Alongside the Formulas
## Character Embedding Layer
```python
self.char_embedding = nn.Embedding(args.char_alphabet_size, self.char_emb_dim)
```
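As a minimal sketch of what this layer does, the snippet below builds a small `nn.Embedding` and looks up a batch of character indices. The sizes (`char_alphabet_size = 10`, `char_emb_dim = 4`) and the index values are made up for illustration and are not taken from the model's configuration.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
char_alphabet_size = 10
char_emb_dim = 4

char_embedding = nn.Embedding(char_alphabet_size, char_emb_dim)

# A batch of 2 sentences, each holding 3 character indices.
char_ids = torch.tensor([[1, 2, 3], [4, 5, 0]])
char_vectors = char_embedding(char_ids)

# Each index is replaced by its row of the embedding matrix.
print(char_vectors.shape)  # torch.Size([2, 3, 4])
```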
## Word Embedding Layer
```python
if self.use_edge:
    # word embedding
    self.word_embedding = nn.Embedding(args.word_alphabet_size, self.word_emb_dim)
    if data.pretrain_word_embedding is not None:
        scale = np.sqrt(3.0 / self.word_emb_dim)
        data.pretrain_word_embedding[0, :] = np.random.uniform(-scale, scale, [1, self.word_emb_dim])
        self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
    # bmes embedding
    self.bmes_embedding = nn.Embedding(4, self.bmes_dim)
```
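The pretrained-embedding step can be tried in isolation. The sketch below uses a random matrix as a stand-in for `data.pretrain_word_embedding` (the real one comes from a pretrained file), re-initializes row 0 with the same uniform range, and copies the result into the layer. All sizes are hypothetical.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical sizes; the real values come from the data and config.
word_alphabet_size, word_emb_dim = 5, 8

# Stand-in for data.pretrain_word_embedding (an assumption for this demo).
rng = np.random.default_rng(0)
pretrained = rng.standard_normal((word_alphabet_size, word_emb_dim)).astype(np.float32)

# U(-a, a) has variance a^2 / 3, so a = sqrt(3 / dim) gives variance 1 / dim.
scale = np.sqrt(3.0 / word_emb_dim)
pretrained[0, :] = rng.uniform(-scale, scale, size=word_emb_dim)

word_embedding = nn.Embedding(word_alphabet_size, word_emb_dim)
# Overwrite the randomly initialized weights with the pretrained matrix.
word_embedding.weight.data.copy_(torch.from_numpy(pretrained))
```

Copying into `weight.data` keeps the parameter trainable, so the pretrained vectors serve as a starting point rather than a frozen lookup table.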
## LSTM Layer
```python
self.emb_rnn_f = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
self.emb_rnn_b = nn.LSTM(self.char_emb_dim, self.hidden_dim, batch_first=True)
# length embedding
self.length_embedding = nn.Embedding(self.max_word_length, self.length_dim)
self.dropout = nn.Dropout(self.emb_dropout_rate)
self.norm = nn.LayerNorm(self.hidden_dim)
```
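To see the tensor shapes these layers produce, here is a rough sketch with made-up sizes. How the original model feeds the backward LSTM is not shown in the snippet above; reversing the sequence along the time axis is an assumption made here for illustration, and padding is ignored.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
char_emb_dim, hidden_dim = 4, 6

emb_rnn_f = nn.LSTM(char_emb_dim, hidden_dim, batch_first=True)
emb_rnn_b = nn.LSTM(char_emb_dim, hidden_dim, batch_first=True)
norm = nn.LayerNorm(hidden_dim)

# batch_first=True means the input is (batch, seq_len, features).
x = torch.randn(2, 5, char_emb_dim)

# Forward direction: read the characters left to right.
out_f, _ = emb_rnn_f(x)

# Assumed backward direction: read the reversed sequence,
# then flip the outputs back to the original order.
out_b, _ = emb_rnn_b(torch.flip(x, dims=[1]))
out_b = torch.flip(out_b, dims=[1])

# LayerNorm normalizes over the last (hidden) dimension.
nodes_f = norm(out_f)
nodes_b = norm(out_b)
print(nodes_f.shape, nodes_b.shape)  # torch.Size([2, 5, 6]) torch.Size([2, 5, 6])
```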
## Node Aggregation
```python
self.edge2node_f = nn.ModuleList(
    [MultiHeadAtt(self.hidden_dim, self.hidden_dim * 2 + self.length_dim,
                  nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
     for _ in range(self.iters)])
```
![file](https://wp.recgroup.cn/wp-content/uploads/2021/01/image-1611204058158.png)
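The per-iteration modules are wrapped in `nn.ModuleList` rather than a plain Python list. The small comparison below shows why that matters: only `nn.ModuleList` registers its sub-modules, so their parameters are visible to the optimizer. The two classes and the `nn.Linear` layers are hypothetical stand-ins for the attention modules.

```python
import torch.nn as nn

iters = 3  # hypothetical number of aggregation iterations


class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: the sub-modules are NOT registered.
        self.layers = [nn.Linear(4, 4) for _ in range(iters)]


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every sub-module and its parameters.
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(iters)])


n_plain = len(list(PlainList().parameters()))
n_registered = len(list(Registered().parameters()))

# Each Linear layer has a weight and a bias, so 3 layers give 6 tensors.
print(n_plain, n_registered)  # 0 6
```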
## Edge Aggregation
```python
self.node2edge_f = nn.ModuleList(
    [MultiHeadAtt(self.hidden_dim, self.hidden_dim + self.bmes_dim, nhead=self.num_head,
                  head_dim=self.head_dim, dropout=self.tf_dropout_rate)
     for _ in range(self.iters)])
```
![file](https://wp.recgroup.cn/wp-content/uploads/2021/01/image-1611204076080.png)
`MultiHeadAtt` is defined in `module.py`. It contains four convolutional layers, one Dropout layer, and one LayerNorm layer.
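If the four convolutional layers in `MultiHeadAtt` use a kernel size of 1 (an assumption here; check `module.py` to confirm), each one behaves like a linear projection applied independently at every position, which is the usual role of the query, key, value, and output projections in multi-head attention. The sketch below checks that equivalence on made-up sizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_dim, out_dim = 8, 6  # hypothetical feature sizes

conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
linear = nn.Linear(in_dim, out_dim)

# Give the linear layer the same weights as the 1x1 convolution.
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(out_dim, in_dim))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 5, in_dim)  # (batch, seq_len, features)

# Conv2d expects (batch, channels, height, width), so move the features
# to the channel axis and add a dummy width of 1.
conv_in = x.permute(0, 2, 1).unsqueeze(-1)               # (2, 8, 5, 1)
conv_out = conv(conv_in).squeeze(-1).permute(0, 2, 1)    # (2, 5, 6)

lin_out = linear(x)                                      # (2, 5, 6)

same = torch.allclose(conv_out, lin_out, atol=1e-6)
```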
## Global Aggregation
```python
self.glo_att_f_node = nn.ModuleList(
    [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
     for _ in range(self.iters)])
if self.use_edge:
    self.glo_att_f_edge = nn.ModuleList(
        [GloAtt(self.hidden_dim, nhead=self.num_head, head_dim=self.head_dim, dropout=self.tf_dropout_rate)
         for _ in range(self.iters)])
# Updating modules
if self.use_edge:
    self.glo_rnn_f = Global_Cell(self.hidden_dim * 3, self.hidden_dim, dropout=self.cell_dropout_rate)
    self.node_rnn_f = Nodes_Cell(self.hidden_dim * 5, self.hidden_dim, dropout=self.cell_dropout_rate)
    self.edge_rnn_f = Edges_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
else:
    self.glo_rnn_f = Global_Cell(self.hidden_dim * 2, self.hidden_dim, dropout=self.cell_dropout_rate)
    self.node_rnn_f = Nodes_Cell(self.hidden_dim * 4, self.hidden_dim, dropout=self.cell_dropout_rate)
```
![file](https://wp.recgroup.cn/wp-content/uploads/2021/01/image-1611204096777.png)
## Optional CRF Layer
```python
if self.use_crf:
    self.hidden2tag = nn.Linear(output_dim, self.label_size + 2)
    self.crf = CRF(self.label_size, self.gpu)
else:
    self.hidden2tag = nn.Linear(output_dim, self.label_size)
    self.criterion = nn.CrossEntropyLoss()
```
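For the branch without a CRF, the sketch below shows one way the linear layer and `nn.CrossEntropyLoss` could be combined for per-character tagging. The sizes, the random features, and the random gold labels are all made up, and the original training code may handle masking and padding differently.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
output_dim, label_size = 6, 3
batch_size, seq_len = 2, 5

hidden2tag = nn.Linear(output_dim, label_size)
criterion = nn.CrossEntropyLoss()

features = torch.randn(batch_size, seq_len, output_dim)
gold = torch.randint(0, label_size, (batch_size, seq_len))

# One score per label for every character.
logits = hidden2tag(features)  # (batch, seq_len, label_size)

# CrossEntropyLoss expects (N, C) scores and (N,) class indices,
# so flatten the batch and sequence dimensions together.
loss = criterion(logits.view(-1, label_size), gold.view(-1))

# Without a CRF, decoding is an independent argmax per character.
pred = logits.argmax(dim=-1)   # (batch, seq_len)
```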
The `CRF` class is defined in `crf.py`. It contains quite a lot of code, so I will draw a diagram for it later.