【pytorch模型实现12】文本匹配之DSSM

24 篇文章 0 订阅
24 篇文章 2 订阅

DSSM模型结构

在这里插入图片描述

代码地址:https://github.com/lyj157175/Models
import torch 
import torch.nn as nn


class DSSM(nn.Module):

    def __init__(self, vocab_size, embedding_dim, dropout):
        super(DSSM, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.dropout = nn.Dropout(dropout)

    def forward(self, a, b):
        a = self.embed(a).sum(1)
        b = self.embed(b).sum(1)

        a = self.dropout(torch.tanh(self.fc1(a)))
        a = self.dropout(torch.tanh(self.fc2(a)))
        a = self.dropout(torch.tanh(self.fc3(a)))

        b = self.dropout(torch.tanh(self.fc1(b)))
        b = self.dropout(torch.tanh(self.fc2(b)))
        b = self.dropout(torch.tanh(self.fc3(b)))

        cosine = torch.cosine_similarity(a, b, dim=1, eps=1e-8)   # 计算两个句子的余弦相似度
        return cosine
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)


if __name__ =='__main__':
    model = DSSM(30, 100, 0.2)
    model._init_weights()
    print(model)
    
  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值