FastText 文本分类（AG News 数据集——新闻主题分类）


Fasttext

（原文此处为一张图片，已缺失）

一、文件目录

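（原文此处为文件目录图片，已缺失。由下文代码可知，训练脚本通过 `from data import AG_Data`、`from model import Fasttext` 和 `import config` 引用其余模块。）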

二、语料集下载地址（本文使用 AG News）

AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz
DBPedia: https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz
Sogou news: https://s3.amazonaws.com/fast-ai-nlp/sogou_news_csv.tgz
Yelp Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_polarity_csv.tgz
Yelp Review Full: https://s3.amazonaws.com/fast-ai-nlp/yelp_review_full_csv.tgz
Yahoo! Answers: https://s3.amazonaws.com/fast-ai-nlp/yahoo_answers_csv.tgz
Amazon Review Full: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_full_csv.tgz
Amazon Review Polarity: https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz

三、数据处理（AG_Dataset.py；注意训练脚本以 `from data import AG_Data` 导入，实际文件名需与之保持一致）

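1. 加载数据集：读取 CSV，第一列为标签（减 1 转为从 0 开始），其余列拼接为文本；
2. 读取标签和数据：文本小写化后用 nltk 分词，并按 n_gram 追加 n-gram 词（以空格连接的词组）；
3. 创建 word2id：
   3.1 统计词频，过滤掉出现次数低于 min_count 的词；
   3.2 先加入 <pad>:0、<unk>:1，再依次加入 uni-gram 和 n-gram，得到 word2id；
4. 将数据转换为 id：n-gram 的 id 通过取模映射到 100000 个哈希桶（排在 uni-gram 之后），每条样本截断或补 0 到 max_length。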

from torch.utils import data
import os
import csv
import nltk
import numpy as np
class AG_Data(data.Dataset):
    """AG News style CSV dataset yielding fixed-length id sequences and int labels.

    Each CSV row is ``label, text...``: the first column is the label (stored
    0-based by subtracting 1), the remaining columns are joined with spaces to
    form the text. Tokens are produced with ``nltk.word_tokenize``, optionally
    extended with n-grams (space-joined tokens), then mapped to ids:

    * 0 = ``<pad>``, 1 = ``<unk>``, then uni-grams seen at least ``min_count`` times;
    * multi-word n-grams are folded into ``ngram_buckets`` buckets placed after
      the ``uniwords_num`` uni-gram ids.

    ``self.data`` is an ``np.ndarray`` of shape ``(num_rows, max_length)`` and
    ``self.y`` an ``np.ndarray`` of labels.
    """

    # Number of hash buckets reserved for multi-word n-grams (the model's
    # embedding table must have ``uniwords_num + ngram_buckets`` rows).
    ngram_buckets = 100000

    def __init__(self, data_path, min_count, max_length, n_gram=1, word2id=None, uniwords_num=0):
        """Load ``data_path`` and convert it to padded id sequences.

        Args:
            data_path: CSV path appended verbatim (string concatenation) to
                ``self.path``, so it is expected to start with "/".
            min_count: minimum frequency for a token/n-gram to get its own id;
                only used when ``word2id`` is None.
            max_length: every sequence is truncated or right-padded with 0 to this length.
            n_gram: largest n-gram order to add (1 means uni-grams only).
            word2id: vocabulary to reuse (e.g. the training set's); built from
                this file when None.
            uniwords_num: number of uni-gram ids (including <pad>/<unk>) in a
                supplied ``word2id``; n-gram ids are offset by it.
        """
        self.path = os.path.abspath(".")
        # NOTE(review): substring test -- any cwd containing "data" anywhere
        # (e.g. ".../dataset") skips the "/data" suffix.
        if "data" not in self.path:
            self.path += "/data"
        self.n_gram = n_gram
        self.load(data_path)  # fills self.label / self.y / self.data (token lists)
        if word2id is None:
            self.get_word2id(self.data, min_count)
        else:
            self.word2id = word2id
            self.uniwords_num = uniwords_num
        self.data = np.array(self.convert_data2id(self.data, max_length))
        self.y = np.array(self.y)

    def load(self, data_path, lowercase=True):
        """Read labels and n-gram-expanded token lists from a CSV file.

        Sets ``self.label`` (and its alias ``self.y``) to 0-based int labels and
        ``self.data`` to one list of token strings per row.
        """
        self.label = []
        self.data = []
        # newline="" is what the csv module requires; an explicit encoding avoids
        # locale-dependent decoding of the corpus.
        with open(self.path + data_path, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f, delimiter=',', quotechar='"'):
                # Labels in the file are assumed to be 1-based.
                self.label.append(int(row[0]) - 1)
                txt = " ".join(row[1:])
                if lowercase:
                    txt = txt.lower()
                tokens = nltk.word_tokenize(txt)  # split the sentence into words
                # For each position i emit the uni-gram, then the 2-gram ending at i,
                # ... up to n_gram, skipping n-grams that would start before token 0.
                self.data.append([
                    " ".join(tokens[i - j:i + 1])
                    for i in range(len(tokens))
                    for j in range(min(self.n_gram, i + 1))
                ])
        self.y = self.label

    def get_word2id(self, datas, min_count=3):
        """Build ``self.word2id`` and ``self.uniwords_num`` from tokenized rows.

        Ids are assigned in first-seen order: <pad>=0, <unk>=1, then every
        uni-gram (no space) with frequency >= ``min_count``, then every n-gram
        (contains a space) with frequency >= ``min_count``. ``uniwords_num`` is
        the number of ids assigned before the first n-gram.
        """
        from collections import Counter

        # Counter keeps first-insertion order, so id assignment is deterministic.
        word_freq = Counter(word for row in datas for word in row)
        word2id = {"<pad>": 0, "<unk>": 1}
        # Uni-grams first: they keep their exact ids and need no hashing.
        for word, freq in word_freq.items():
            if freq >= min_count and " " not in word:
                word2id[word] = len(word2id)
        self.uniwords_num = len(word2id)
        # Multi-word n-grams afterwards; convert_data2id folds them into buckets.
        for word, freq in word_freq.items():
            if freq >= min_count and " " in word:
                word2id[word] = len(word2id)
        self.word2id = word2id

    def convert_data2id(self, datas, max_length):
        """Replace tokens with ids in place and pad/truncate each row to ``max_length``.

        Uni-grams map to their ``word2id`` id (1 = <unk> when missing). N-grams map
        to ``word2id.get(ngram, 1) % ngram_buckets + uniwords_num``. Rows shorter
        than ``max_length`` are right-padded with 0. Returns ``datas`` (mutated).
        """
        buckets = self.ngram_buckets
        for i, row in enumerate(datas):
            for j, word in enumerate(row):
                if " " not in word:
                    row[j] = self.word2id.get(word, 1)
                else:
                    # NOTE(review): every unseen n-gram shares the bucket
                    # 1 % buckets + uniwords_num rather than mapping to <unk>.
                    row[j] = self.word2id.get(word, 1) % buckets + self.uniwords_num
            datas[i] = row[:max_length] + [0] * (max_length - len(row))
        return datas

    def __getitem__(self, idx):
        """Return ``(ids, label)`` for sample ``idx``."""
        X = self.data[idx]
        y = self.y[idx]
        return X, y

    def __len__(self):
        """Number of samples (rows read from the CSV)."""
        return len(self.label)
if __name__ == "__main__":
    # Smoke test: load the AG training split (min_count=3, max_length=100)
    # and report the array shapes plus the vocabulary size.
    train = AG_Data("/AG/train.csv", 3, 100)
    ids, labels = train.data, train.y
    print(ids.shape)
    print(ids[-20:])
    print(labels.shape)
    print(len(train.word2id))

四、模型（Fasttext.py；注意训练脚本以 `from model import Fasttext` 导入，实际文件名需与之保持一致）

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Fasttext(nn.Module):
    """FastText classifier: embed ids, average over the sequence, apply a linear layer.

    Args:
        vocab_size: number of embedding rows (uni-gram ids plus n-gram buckets).
        embedding_size: embedding dimension.
        max_length: sequence length; the average-pooling kernel spans exactly
            this many positions, so inputs are expected to have this length.
        label_num: number of output classes.
    """

    def __init__(self, vocab_size, embedding_size, max_length, label_num):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.avg_pool = nn.AvgPool1d(kernel_size=max_length, stride=1)
        self.fc = nn.Linear(embedding_size, label_num)

    def forward(self, x):
        """Map ``x`` of shape (batch, max_length) to logits of shape (batch, label_num)."""
        out = self.embedding(x)  # (batch, length, embedding_size)
        # AvgPool1d pools over the last axis, so move the sequence axis there.
        out = out.transpose(1, 2).contiguous()  # (batch, embedding_size, length)
        # Squeeze only the pooled axis: a bare squeeze() would also drop the
        # batch axis when batch == 1 and break argmax(out, 1) / the loss.
        out = self.avg_pool(out).squeeze(-1)  # (batch, embedding_size)
        return self.fc(out)  # (batch, label_num)
if __name__ == "__main__":
    # Smoke test: 64 all-<pad> sequences of length 100 through a 4-class model;
    # prints the logits' size.
    model = Fasttext(100, 200, 100, 4)
    inputs = torch.zeros(64, 100, dtype=torch.long)
    print(model(inputs).size())

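五、训练和测试（依赖文中未给出的 config.py：其 `ArgumentParser()` 返回的对象需提供 seed、cuda、gpu、min_count、max_length、n_gram、batch_size、embed_size、label_num、learning_rate、epoch 等参数）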


import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from model import Fasttext
from data import AG_Data
import numpy as np
from tqdm import tqdm
import config as argumentparser
# Run configuration from the project-local `config` module (presumably parsed
# command-line arguments -- confirm against config.py), then seed torch's RNG
# so weight initialisation and shuffling are repeatable.
config = argumentparser.ArgumentParser()
torch.manual_seed(config.seed)

# Select the configured GPU only when CUDA was requested and is actually present.
if config.cuda:
    if torch.cuda.is_available():
        torch.cuda.set_device(config.gpu)
def get_test_result(data_iter, data_set):
    """Return the accuracy of the module-level ``model`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(data, label)`` batches (a DataLoader over ``data_set``).
        data_set: the dataset behind ``data_iter``; its length is the denominator.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Side effect: puts ``model`` into eval mode.
    """
    model.eval()
    true_sample_num = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data, label in data_iter:
            if config.cuda and torch.cuda.is_available():
                data = data.cuda()
                label = label.cuda()
            out = model(data.long())
            true_sample_num += (torch.argmax(out, 1) == label).sum().item()
    return true_sample_num / len(data_set)
# Training split: builds its vocabulary (word2id) from scratch.
training_set = AG_Data(
    "/AG/train.csv",
    min_count=config.min_count,
    max_length=config.max_length,
    n_gram=config.n_gram,
)
training_iter = torch.utils.data.DataLoader(
    dataset=training_set,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,
)

# Test split: reuses the training vocabulary and uni-gram count so that ids,
# including the offset n-gram bucket ids, match the model's embedding rows.
test_set = AG_Data(
    data_path="/AG/test.csv",
    min_count=config.min_count,
    max_length=config.max_length,
    n_gram=config.n_gram,
    word2id=training_set.word2id,
    uniwords_num=training_set.uniwords_num,
)
test_iter = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=0,
)

# Embedding table: all uni-gram ids followed by 100000 n-gram buckets.
model = Fasttext(
    vocab_size=training_set.uniwords_num + 100000,
    embedding_size=config.embed_size,
    max_length=config.max_length,
    label_num=config.label_num,
)
if config.cuda:
    if torch.cuda.is_available():
        model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
# Exponential moving average of the training loss; None until the first batch.
loss = None
for epoch in range(config.epoch):
    model.train()
    process_bar = tqdm(training_iter)
    # Iterating a tqdm bar advances it automatically; no manual update() call.
    for data, label in process_bar:
        if config.cuda and torch.cuda.is_available():
            data = data.cuda()
            label = label.cuda()
        # Labels already arrive as a 1-D batch; squeezing would turn a batch
        # of one into a 0-d tensor that CrossEntropyLoss rejects.
        out = model(data.long())
        loss_now = criterion(out, label.long())
        loss_value = loss_now.item()
        loss = loss_value if loss is None else 0.95 * loss + 0.05 * loss_value
        process_bar.set_postfix(loss=loss)  # show the smoothed loss
        optimizer.zero_grad()
        loss_now.backward()
        optimizer.step()
    test_acc = get_test_result(test_iter, test_set)
    print("The test acc is: %.5f" % test_acc)

实验结果

输出测试集准确率:
（原文此处为测试集准确率输出截图，已缺失）

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值