AI基于深度学习的代码搜索案例(二)

 6. 训练模型

以下使用PyTorch来定义和训练模型,由于篇幅原因,这里省略了很多代码,只贴出了部分关键代码,可执行的完整代码可参照src/3Model/joint_embedder.py

6.1 定义模型

从前面模型细节图中可以看到,我们对方法名、API序列、描述都使用了同样的编码器,我们可以针对序列数据统一定义一个编码器:

class SeqEncoder(nn.Module):
  def __init__(self, vocab_size, embed_size, hidden_size, rnn='lstm',
               bidirectional=True, pool='max', activation='tanh'):
    self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
    self.rnn = rnn_choices[rnn](embed_size, hidden_size, bidirectional)
    self.pool = pool_choices[pool]
    self.activation = activations_choices[activation]

  def forward(self, input):
    embedded = F.dropout(self.embedding(input), 0.25, self.training)
    rnn_output = F.dropout(self.rnn(embedded)[0], 0.25, self.training)
    return self.activation(self.pool(rnn_output, dim=1))

同样的,针对标识符使用的MLP,我们也定义一个编码器:

class BOWEncoder(nn.Module):
  def __init__(self, vocab_size, embed_size, pool='max', activation='tanh'):
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.pool = pool_choices[pool]
    self.activation = activations_choices[activation]

  def forward(self, input):
    embedded = F.dropout(self.embedding(input), 0.25, self.training)
    return self.activation(self.pool(embedded, dim=1))

这样,我们的代码编码网络CoNN可以定义为:

def forward_code(self, name, apis, tokens):
  name_repr = SeqEncoder(name)
  apis_repr = SeqEncoder(apis)
  tokens_repr = BOWEncoder(tokens)
  code_repr = nn.Linear(torch.cat((name_repr, apis_repr, tokens_repr), 1))
  return torch.tanh(code_repr)

描述编码网络DeNN可以定义为:

def forward_desc(self, desc):
  return SeqEncoder(desc)

6.2 训练模型

真实训练时,我们给每个代码段两个描述,一个是原始描述,即正确的描述,另一个是随机描述,即错误描述,然后选取下面的值作为loss:

def forward(self, name, apis, tokens, desc_good, desc_bad):
  code_repr = forward_code(name, apis, tokens)
  good_sim = F.cosine_similarity(code_repr, forward_desc(desc_good))
  bad_sim = F.cosine_similarity(code_repr, forward_desc(desc_bad))
  return (margin - good_sim + bad_sim).clamp(min=1e-6)

具体的代码在src/3Model/codenn.py中,对应的训练命令为

python src/3Model/codenn.py --dataset_path data/py/final --model_p
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值