Xiao Hei's Line-by-Line Shape Analysis and Debugging: DialogueRNN

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
# DialogueRNN itself is re-implemented below for line-by-line analysis, so only the cell is imported
from model import DialogueRNNCell
class DialogueRNN(nn.Module):
    def __init__(self, D_m, D_g, D_p, D_e, context_attention='simple', D_a=100, dropout=0.5):
        super(DialogueRNN, self).__init__()
        self.D_m = D_m    # D_m: utterance (sentence) embedding size
        self.D_g = D_g    # D_g: global context state size
        self.D_p = D_p    # D_p: party (speaker) state size
        self.D_e = D_e    # D_e: emotion representation size
        self.dropout = nn.Dropout(dropout)
        # D_a: attention size
        self.dialogue_cell = DialogueRNNCell(D_m, D_g, D_p, D_e, context_attention, D_a, dropout)
    def forward(self, U, qmask):
        # U: [num_seqs, batch_size, D_m]    qmask: [num_seqs, batch_size, num_party]
        g_hist = torch.zeros(0).type(U.type())    # empty history of global states
        # q_: [batch_size, num_party, D_p]
        q_ = torch.zeros(qmask.size()[1], qmask.size()[2], self.D_p).type(U.type())
        # e_: empty at t == 1
        e_ = torch.zeros(0).type(U.type())
        e = e_
        alpha = []
        for u_, qmask_ in zip(U, qmask):
            # u_: [batch_size, D_m]    qmask_: [batch_size, num_party]
            # g_hist: [t-1, batch_size, D_g]
            # input  q_: [batch_size, num_party, D_p]
            # input  e_: [] at t == 1, else [batch_size, D_e]
            # output q_: [batch_size, num_party, D_p]
            # output e_: [batch_size, D_e]
            # alpha_: [batch_size, 1, t-1], or None at t == 1
            g_, q_, e_, alpha_ = self.dialogue_cell(u_, qmask_, g_hist, q_, e_)
            # g_hist: [t, batch_size, D_g]
            g_hist = torch.cat([g_hist, g_.unsqueeze(0)], 0)
            # e: [t, batch_size, D_e]
            e = torch.cat([e, e_.unsqueeze(0)], 0)
            if alpha_ is not None:
                alpha.append(alpha_[:, 0, :])
        # e: [seq_len, batch_size, D_e]
        # alpha: (seq_len-1) tensors of shape [batch_size, t-1], 2 <= t <= seq_len
        return e, alpha
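
The shapes annotated in forward all come from DialogueRNNCell, which lives in model.py. For readers without that file at hand, here is a condensed sketch of the cell, based on the original DialogueRNN implementation (the listener_state branch and the MatchingAttention variant are omitted); treat it as an illustration of the shape flow, not a drop-in replacement:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAttention(nn.Module):
    # scores each past global state with a learned scalar and softmax-pools them
    def __init__(self, input_dim):
        super(SimpleAttention, self).__init__()
        self.scalar = nn.Linear(input_dim, 1, bias=False)

    def forward(self, M):
        # M: [t-1, batch_size, D_g]
        scale = self.scalar(M)                                     # [t-1, batch_size, 1]
        alpha = F.softmax(scale, dim=0).permute(1, 2, 0)           # [batch_size, 1, t-1]
        attn_pool = torch.bmm(alpha, M.transpose(0, 1))[:, 0, :]   # [batch_size, D_g]
        return attn_pool, alpha

class DialogueRNNCell(nn.Module):
    def __init__(self, D_m, D_g, D_p, D_e, context_attention='simple', D_a=100, dropout=0.5):
        super(DialogueRNNCell, self).__init__()
        self.D_m, self.D_g, self.D_p, self.D_e = D_m, D_g, D_p, D_e
        self.g_cell = nn.GRUCell(D_m + D_p, D_g)    # global-state GRU
        self.p_cell = nn.GRUCell(D_m + D_g, D_p)    # party-state GRU
        self.e_cell = nn.GRUCell(D_p, D_e)          # emotion-state GRU
        self.attention = SimpleAttention(D_g)
        self.dropout = nn.Dropout(dropout)

    def _select_parties(self, X, indices):
        # X: [batch_size, num_party, D] -> current speaker's row: [batch_size, D]
        return torch.cat([x[idx].unsqueeze(0) for idx, x in zip(indices, X)], 0)

    def forward(self, U, qmask, g_hist, q0, e0):
        qm_idx = torch.argmax(qmask, 1)              # speaker index of each u_
        q0_sel = self._select_parties(q0, qm_idx)    # [batch_size, D_p]
        # global state from (utterance, speaker's previous state)
        g_ = self.g_cell(torch.cat([U, q0_sel], dim=1),
                         torch.zeros(U.size(0), self.D_g).type(U.type())
                         if g_hist.size(0) == 0 else g_hist[-1])
        g_ = self.dropout(g_)
        # context c_ from attention over the history (zeros and alpha = None at t == 1)
        if g_hist.size(0) == 0:
            c_, alpha = torch.zeros(U.size(0), self.D_g).type(U.type()), None
        else:
            c_, alpha = self.attention(g_hist)
        # candidate party states from (utterance, context), computed for every party slot
        U_c_ = torch.cat([U, c_], dim=1).unsqueeze(1).expand(-1, qmask.size(1), -1)
        qs_ = self.p_cell(U_c_.contiguous().view(-1, self.D_m + self.D_g),
                          q0.view(-1, self.D_p)).view(U.size(0), -1, self.D_p)
        qs_ = self.dropout(qs_)
        # only the speaker's slot is updated; listeners keep their previous state
        qmask_ = qmask.unsqueeze(2)
        q_ = q0 * (1 - qmask_) + qs_ * qmask_
        # emotion state from the speaker's fresh party state
        e0 = torch.zeros(qmask.size(0), self.D_e).type(U.type()) if e0.size(0) == 0 else e0
        e_ = self.dropout(self.e_cell(self._select_parties(q_, qm_idx), e0))
        return g_, q_, e_, alpha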
# Test script
D_m = 100
D_g = 150
D_p = 150
D_e = 100
model = DialogueRNN(D_m, D_g, D_p, D_e)
batch_size = 5
num_seqs = 20    # 20 utterances per dialogue

U = torch.randn(num_seqs, batch_size, D_m)
# one-hot speaker mask, [num_seqs, batch_size, num_party];
# FloatTensor so it can gate the party states inside the cell
qmask = torch.FloatTensor([[[0, 1], [1, 0]] * 10] * batch_size).transpose(0, 1)
e, alpha = model(U, qmask)
print('Parameter settings:')
print('batch_size:', batch_size, '\nnum_seqs:', num_seqs)
print('Shapes:')
print('e:', e.shape)
print('alpha length:', len(alpha), '\nalpha[0]:', alpha[0].shape)
print('alpha[1]:', alpha[1].shape, '...')
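
A note on the qmask line above: it builds a one-hot speaker mask in which two parties alternate utterances, identically for every dialogue in the batch. An equivalent construction via F.one_hot (purely illustrative, not part of the script) makes the intent clearer:

# speaker index per utterance: party 1, party 0, party 1, ... (party 1 speaks first, as above)
speakers = torch.tensor([[1, 0] * 10] * batch_size).transpose(0, 1)   # [num_seqs, batch_size]
qmask_alt = F.one_hot(speakers, num_classes=2).float()                # [num_seqs, batch_size, num_party]
assert torch.equal(qmask_alt, qmask)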

Output:

Parameter settings:
batch_size: 5
num_seqs: 20
Shapes:
e: torch.Size([20, 5, 100])
alpha length: 19
alpha[0]: torch.Size([5, 1])
alpha[1]: torch.Size([5, 2]) ...
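
The alpha list is ragged because the attention at step t covers only the t-1 previous global states. If a single tensor is more convenient downstream, the pad_sequence imported at the top can stack it; a small sketch:

# transpose each [batch_size, t-1] entry so the ragged time dim comes first, then pad
alpha_padded = pad_sequence([a.transpose(0, 1) for a in alpha])
print(alpha_padded.shape)    # torch.Size([19, 19, 5]) = [max_t, num_steps, batch_size]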
