6. 训练模型
以下使用PyTorch
来定义和训练模型,由于篇幅原因,这里省略了很多代码,只贴出了部分关键代码,可执行的完整代码可参照src/3Model/joint_embedder.py
。
6.1 定义模型
从前面模型细节图中可以看到,我们对方法名、API序列、描述都使用了同样的编码器,我们可以针对序列数据统一定义一个编码器:
class SeqEncoder(nn.Module): def __init__(self, vocab_size, embed_size, hidden_size, rnn='lstm', bidirectional=True, pool='max', activation='tanh'): self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0) self.rnn = rnn_choices[rnn](embed_size, hidden_size, bidirectional) self.pool = pool_choices[pool] self.activation = activations_choices[activation] def forward(self, input): embedded = F.dropout(self.embedding(input), 0.25, self.training) rnn_output = F.dropout(self.rnn(embedded)[0], 0.25, self.training) return self.activation(self.pool(rnn_output, dim=1))
同样的,针对标识符使用的MLP,我们也定义一个编码器:
class BOWEncoder(nn.Module): def __init__(self, vocab_size, embed_size, pool='max', activation='tanh'): self.embedding = nn.Embedding(vocab_size, embed_size) self.pool = pool_choices[pool] self.activation = activations_choices[activation] def forward(self, input): embedded = F.dropout(self.embedding(input), 0.25, self.training) return self.activation(self.pool(embedded, dim=1))
这样,我们的代码编码网络CoNN
可以定义为:
def forward_code(self, name, apis, tokens): name_repr = SeqEncoder(name) apis_repr = SeqEncoder(apis) tokens_repr = BOWEncoder(tokens) code_repr = nn.Linear(torch.cat((name_repr, apis_repr, tokens_repr), 1)) return torch.tanh(code_repr)
描述编码网络DeNN
可以定义为:
def forward_desc(self, desc): return SeqEncoder(desc)
6.2 训练模型
真实训练时,我们给每个代码段两个描述,一个是原始描述,即正确的描述,另一个是随机描述,即错误描述,然后选取下面的值作为loss:
def forward(self, name, apis, tokens, desc_good, desc_bad): code_repr = forward_code(name, apis, tokens) good_sim = F.cosine_similarity(code_repr, forward_desc(desc_good)) bad_sim = F.cosine_similarity(code_repr, forward_desc(desc_bad)) return (margin - good_sim + bad_sim).clamp(min=1e-6)
具体的代码在src/3Model/codenn.py
中,对应的训练命令为
python src/3Model/codenn.py --dataset_path data/py/final --model_p