关于《黑马程序员》课程中NLP中 训练新闻分类模型
最近在学习NLP的相关知识,找了资料比较全的黑马程序员中讲解NLP的课程,可是其中有一部分实战 新闻主题分类实战项目中,我发现黑马程序员代码有大两的错误,多处代码逻辑错误:
- 首先是数据集下载太慢,因为需要翻墙才能下载,所以大部分情况在加载数据集就会出现Timeout异常
- 数据集的处理,在课程中并没有提到,加载本地的csv数据集文件出现的格式不对的情况
- 其次,generator_banth()这个方法中返回的数据对象元组形式是不对的,新闻数据集的元组是3项(type, title ,content)分别是新闻的类型,新闻的标题和新闻的内内容,但是在课程却只有两项。
!!!需要注意的是 torchtext 的版本是0.4 ,可能是版本更新后,这个模块被移走了,如果不是0.4 可能会出现from torchtext.datasets.text_classification 这句话错误!!!
针对上述问题,我整理了一个完整可以正常运行的完整代码,希望给个小心心或者关注我一下呀~
先放代码
from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator, _create_data_from_iterator
import os
import time
from torch import optim
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer
N_GRAMS =2
if not os.path.isdir('./data'):
os.mkdir('./data')
BATCH_SIZE = 16
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')
# 定义创建数据集
def _setup_data_set(dataset_tar='./data/ag_news_csv.tar.gz',
n_grams=N_GRAMS, vocab=None,
include_unk=False):
extracted_files = extract_archive(dataset_tar)
train_csv_path = ''
test_csv_path = ''
for file_name in extracted_files:
if file_name.endswith('train.csv'):
train_csv_path = file_name
if file_name.endswith('test.csv'):
test_csv_path=file_name
if vocab is None:
print("Building Vocab based on %s" % train_csv_path)
# 创建词典
vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams=n_grams))
else:
if not isinstance(vocab, Vocab):
raise TypeError("Passed vocabulary is not of type Vocab")
print('Vocab has %d entries' % len(vocab))
print('Creating training data')
train_data, train_labels = _create_data_from_iterator(
vocab, _csv_iterator(test_csv_path