A subset of the 20 Newsgroups dataset, used as a four-class text-classification task.
model.py
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN_Text(nn.Module):
    """Kim-style convolutional network for sentence classification.

    Embeds token ids, applies parallel 2-D convolutions with several
    kernel heights over the (seq_len, embed_dim) plane, max-pools each
    feature map over time, and classifies the concatenated features.

    Expected ``args`` attributes: ``embed_num``, ``embed_dim``,
    ``class_num``, ``kernel_num``, ``kernel_sizes``, ``dropout``.
    """

    def __init__(self, args):
        super(CNN_Text, self).__init__()
        self.args = args
        in_channels = 1  # a sentence is one "image" channel
        self.embed = nn.Embedding(args.embed_num, args.embed_dim)
        # One conv per kernel height; each spans the full embedding width.
        self.convs_list = nn.ModuleList([
            nn.Conv2d(in_channels, args.kernel_num, (height, args.embed_dim))
            for height in args.kernel_sizes
        ])
        self.dropout = nn.Dropout(args.dropout)
        self.fc = nn.Linear(len(args.kernel_sizes) * args.kernel_num,
                            args.class_num)

    def forward(self, x):
        """Map token-id batch ``x`` of shape (batch, seq_len) to logits."""
        embedded = self.embed(x).unsqueeze(1)  # (batch, 1, seq_len, embed_dim)
        pooled = []
        for conv in self.convs_list:
            feature = F.relu(conv(embedded)).squeeze(3)  # (batch, kernel_num, L)
            # Max over the remaining time dimension -> (batch, kernel_num).
            pooled.append(F.max_pool1d(feature, feature.size(2)).squeeze(2))
        features = self.dropout(torch.cat(pooled, 1))
        # Kept from the original: a flatten that is already a no-op here.
        features = features.view(features.size(0), -1)
        return self.fc(features)
```

mydataset.py

```python
import os
import re
import tarfile
import urllib
from torchtext import data
class TarDataset(data.Dataset):
    """Defines a Dataset loaded from a downloadable tar archive.

    Attributes:
        url: URL where the tar archive can be downloaded.
        filename: Filename of the downloaded tar archive.
        dirname: Name of the top-level directory within the archive that
            contains the data files.
    """

    @classmethod
    def download_or_unzip(cls, root):
        """Ensure the data directory exists under ``root``.

        Downloads the archive if it is missing and extracts it on first
        use; both steps are skipped when ``root/cls.dirname`` already
        exists.

        Returns:
            The extracted data directory path with a trailing separator.
        """
        path = os.path.join(root, cls.dirname)
        if not os.path.isdir(path):
            tpath = os.path.join(root, cls.filename)
            if not os.path.isfile(tpath):
                print('downloading')
                # Bug fix: `import urllib` alone does not bind the
                # `urllib.request` submodule in Python 3, so the original
                # `urllib.request.urlretrieve(...)` could raise
                # AttributeError; import the function explicitly here.
                from urllib.request import urlretrieve
                urlretrieve(cls.url, tpath)
            with tarfile.open(tpath, 'r') as tfile:
                print('extracting')
                # NOTE(review): extractall on an untrusted archive can write
                # outside `root` (path traversal). On Python >= 3.12 consider
                # passing `filter="data"`.
                tfile.extractall(root)
        return os.path.join(path, '')
class NEWS_20(TarDataset):
url = 'http://people.csail.mit.edu/jrennie/20Newsgroups/20news-bydate.tar.gz'
filename = 'data/20news-bydate-train'
dirname = ''
@staticmethod#以后重构类的时候不必要修改构造函数,只需要额外添加你要处理的函数,然后使用装饰符 @classmethod 就可以了
def sort_key(ex):
return len(ex.text)
def __init__(self, text_field, label_field, path=None, text_cnt=1000, examples=None, **kwargs):
"""Create an MR dataset instance given a path and fields.
Arguments:
text_field: The field that will be used for text data.
label_field: The field that will be used for label data.
path: Path to the data file.
examples: The examples contain all the data.
Remaining keyword arguments: Passed to the constructor of
data.Dataset.
"""
def clean_str(string):
"""
Tokenization/string cleaning for all datasets except for SST.
Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
"""
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
string = re.sub(r"\'s", " \'s", string)
string = re.sub(r"\'ve", " \'ve", string)
string = re.sub(r"n\'t", " n\'t", string)
string = re.sub(r"\'re", " \'re", string)
string = re.sub(r"\'d", " \'d", string)
string = re.sub(r"\'ll", " \'ll", string)
string = re.sub(r",", " , ", string)
string = re