# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
"""上接配置参数信息"""classModel(nn.Module):def__init__(self, config):super(Model, self).__init__()
        # Embedding layer; the last vocabulary index is reserved for padding
        self.embedding = nn.Embedding(config.n_vocab, config.embed,
                                      padding_idx=config.n_vocab - 1)
        # One 2-D convolution per filter size; each kernel spans the full embedding width
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        # Pooled features from all filter sizes are concatenated before classification
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        # Convolve, squeeze away the width-1 dimension, then max-pool over the sequence
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x
    def forward(self, x):
        out = self.embedding(x[0])  # x[0]: token ids of shape [batch, seq_len]
        out = out.unsqueeze(1)      # add a channel dimension for Conv2d
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
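
To sanity-check the wiring, a minimal forward pass can be run as below. Every value in this config (vocabulary size, embedding width, filter sizes, dropout, class count) is an illustrative assumption, not taken from the original Config class.

# A minimal smoke test of the model (all config values below are assumptions)
import types

config = types.SimpleNamespace(
    n_vocab=10000,           # assumed vocabulary size
    embed=300,               # assumed embedding dimension
    num_filters=256,         # assumed feature maps per filter size
    filter_sizes=(2, 3, 4),  # assumed kernel heights
    dropout=0.5,
    num_classes=10,
)
model = Model(config)
ids = torch.randint(0, config.n_vocab, (8, 32))  # batch of 8 sequences, length 32
logits = model((ids,))                           # forward reads x[0] as the id tensor
print(logits.shape)                              # torch.Size([8, 10])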
Step 2: Load the data in batches (load_data_iter.py)
# coding: UTF-8
from tqdm import tqdm
from unit25.TextCNN import Config
MAX_VOCAB_SIZE = 10000  # cap on the vocabulary size
UNK, PAD = '<UNK>', '<PAD>'  # unknown-token and padding symbols

def build_vocab(file_path, tokenizer, max_size, min_freq):
    vocab_dic = {}
    with open(file_path, 'r', encoding='UTF-8') as f:
        for line in tqdm(f):
            lin = line.strip()
            if not lin:
                continue
            content = lin.split('\t')[0]
            for word in tokenizer(content):  # count how often each token occurs
                vocab_dic[word] = vocab_dic.get(word, 0) + 1
    # keep tokens at or above min_freq, sorted by frequency in descending order
    vocab_list = sorted([kv for kv in vocab_dic.items() if kv[1] >= min_freq],
                        key=lambda x: x[1], reverse=True)[:max_size]
    # map each surviving token to an integer index
    vocab_dic = {word_count[0]: idx for idx, word_count in enumerate(vocab_list)}
    # append the UNK and PAD symbols at the end of the vocabulary
    vocab_dic.update({UNK: len(vocab_dic), PAD: len(vocab_dic) + 1})
    return vocab_dic
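
A hedged usage sketch follows. The training-file path is an illustrative placeholder, and the character-level tokenizer mirrors the per-character counting described in the comments above; each input line is assumed to look like "<text>\t<label>".

# Hypothetical usage (path and tokenizer are illustrative assumptions)
tokenizer = lambda s: [ch for ch in s]  # split text into individual characters
vocab = build_vocab('data/train.txt', tokenizer, MAX_VOCAB_SIZE, min_freq=1)
print(len(vocab))  # vocabulary size, including <UNK> and <PAD>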