Xiaohei's Notes: DialogueRNN
爱喝喜茶爱吃烤冷面的小黑黑
Xiaohei's Code Notes: DialogueRNN Training Notes
import torch
import torch.nn as nn
import torch.optim as optim
import datetime as dt
import os
from torch.utils.data import DataLoader
import numpy as np, pickle, time, argparse
from dataloader import DailyDialoguePadCollate, DailyDialogueDataset
from model im…
Original · 2022-02-22 17:41:33 · 370 views · 1 comment
Xiaohei's Line-by-Line Dimension Analysis and Debugging: dataloader.py
1. dataset
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pickle
import os
import re
import pandas as pd
class DailyDialogueDataset(Dataset):
    def __init__(self,split,path):
        self.Speakers,sel…
Original · 2022-02-22 14:42:33 · 258 views · 0 comments
Xiaohei's Line-by-Line Dimension Analysis and Debugging: SimpleAttention, MaskedNLLLoss
1. MaskedNLLLoss
import torch
import torch.nn as nn
class MaskedNLLLoss(nn.Module):
    def __init__(self,weight = None):
        super(MaskedNLLLoss,self).__init__()
        self.weight = weight
        self.loss = nn.NLLLoss(weight = weight,reduction = 's…
Original · 2022-02-21 23:09:05 · 387 views · 0 comments
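The excerpt above is cut off before the loss computation, so what follows is not the post's implementation. It is a minimal, dependency-free sketch of the general idea behind a masked negative log-likelihood loss: padded positions are zeroed out by a 0/1 mask, and the sum is averaged over the real positions only. The function name `masked_nll_loss` and the plain-list inputs are assumptions made for illustration, and class weights (the `weight` argument in the excerpt) are deliberately left out.

```python
import math


def masked_nll_loss(log_probs, targets, mask):
    """Average negative log-likelihood over unmasked positions.

    log_probs: N rows of C log-probabilities (hypothetical plain-list input)
    targets:   N class indices
    mask:      N values, 1 for a real utterance and 0 for padding
    """
    total = 0.0
    for row, target, keep in zip(log_probs, targets, mask):
        # Padded positions contribute nothing because keep == 0.
        total += -row[target] * keep
    count = sum(mask)
    if count == 0:
        raise ValueError("mask selects no positions")
    return total / count


# Two positions, two classes; the second position is padding.
example_log_probs = [
    [math.log(0.5), math.log(0.5)],
    [math.log(0.25), math.log(0.75)],
]
example_loss = masked_nll_loss(example_log_probs, [0, 1], [1, 0])
```

Here only the first position counts, so `example_loss` equals `-log(0.5)` no matter what the padded row predicts.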
Xiaohei's Line-by-Line Dimension Analysis and Debugging: MatchingAttention
import torch
import torch.nn as nn
import torch.nn.functional as F
class MatchingAttention(nn.Module):
    def __init__(self,mem_dim,cand_dim,alpha_dim = None,att_type = 'general'):
        super(MatchingAttention,self).__init__()
        assert att_type …
Original · 2022-02-21 20:45:15 · 234 views · 0 comments
Xiaohei's Line-by-Line Dimension Analysis and Debugging: DialogueRNNCell
Overall model diagram
g_cell
p_cell
e_cell
Model code
import torch.nn as nn
import torch
from model import SimpleAttention
class DialogueRNNCell(nn.Module):
    def __init__(self,D_m,D_g,D_p,D_e,context_attention = 'simple',D_a = 100,dropout = 0.5):
        super(Dialog…
Original · 2022-02-21 13:49:56 · 728 views · 0 comments
Xiaohei's Line-by-Line Dimension Analysis and Debugging: DialogueRNN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from model import DialogueRNN,MatchingAttention,DialogueRNNCell
class DialogueRNN(nn.Module):
    def __init__(self,D_m,D_g,D_p,D_e,context_atten…
Original · 2022-02-20 23:14:31 · 303 views · 0 comments
Xiaohei's Line-by-Line Dimension Analysis and Debugging: DailyDialogueModel
1. Data preparation
import torch
import torch.nn as nn
import torch.optim as optim
import datetime as dt
import os
from torch.utils.data import DataLoader
import numpy as np, pickle, time, argparse
from model import DailyDialogueModel, MaskedNLLLoss
from dataloader imp…
Original · 2022-02-20 19:50:07 · 660 views · 1 comment
Xiaohei's GRUCell Demo
Inputs:
input: [batch, input_size]
hidden: [batch, hidden_size]
Output:
h′: [batch, hidden_size]
Parameters:
GRUCell.weight_ih: [3 x hidden_size, input_size]
GRUCell.weight_hh: [3 x hidden_size, hidden_size]
GRUCell.bias_ih: [3 x hidden_size]
GRUCell.bias_hh: [3 x …
Original · 2022-02-18 14:59:20 · 988 views · 0 comments
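To make the shapes listed above concrete, here is a small pure-Python sketch of a single GRU step. It is not the post's demo and it does not call PyTorch. It reimplements the standard GRU cell equations with plain lists, assuming the three stacked blocks of each weight and bias hold the reset, update, and new gates in that order. The helper names (`sigmoid`, `matvec`, `gru_cell`, `gru_cell_batch`) are hypothetical.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(matrix, vector):
    """Multiply a [rows, cols] matrix (list of rows) by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def gru_cell(x, h, weight_ih, weight_hh, bias_ih, bias_hh):
    """One GRU step for a single example.

    x: [input_size], h: [hidden_size]
    weight_ih: [3 * hidden_size, input_size]
    weight_hh: [3 * hidden_size, hidden_size]
    bias_ih, bias_hh: [3 * hidden_size]
    Assumed block order inside each parameter: reset, update, new.
    """
    size = len(h)
    assert len(weight_ih) == 3 * size and len(weight_hh) == 3 * size
    assert len(bias_ih) == 3 * size and len(bias_hh) == 3 * size
    gi = [a + b for a, b in zip(matvec(weight_ih, x), bias_ih)]
    gh = [a + b for a, b in zip(matvec(weight_hh, h), bias_hh)]
    r = [sigmoid(gi[k] + gh[k]) for k in range(size)]
    z = [sigmoid(gi[size + k] + gh[size + k]) for k in range(size)]
    # The reset gate scales only the hidden-side contribution to the new gate.
    n = [math.tanh(gi[2 * size + k] + r[k] * gh[2 * size + k]) for k in range(size)]
    # h' = (1 - z) * n + z * h
    return [(1.0 - z[k]) * n[k] + z[k] * h[k] for k in range(size)]


def gru_cell_batch(xs, hs, weight_ih, weight_hh, bias_ih, bias_hh):
    """Apply gru_cell row by row: xs is [batch, input_size], hs is [batch, hidden_size]."""
    return [gru_cell(x, h, weight_ih, weight_hh, bias_ih, bias_hh) for x, h in zip(xs, hs)]
```

With all weights and biases set to zero, both sigmoid gates evaluate to 0.5 and the candidate state is `tanh(0) = 0`, so the step returns exactly half of the previous hidden state. That makes a convenient sanity check on the shapes and the update rule.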