Pytorch:循环神经网络-LSTM

Pytorch: 循环神经网络:LSTM进行新闻分类

Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, School of Artificial and Intelligence, Huazhong University of Science and Technology

Pytorch教程专栏链接


本教程不商用,仅供学习和参考交流使用,如需转载,请联系本人。

详细的 LSTM 结构可以参考教程的上篇文章

本文主要是采用门控循环单元网络 LSTM 来进行新闻类别分类,大家也可以尝试把模型改成下篇文章的 GRU 对比两种网络的效果。

使用 THUCNews 数据库进行分类,一共包含 10 10 10 类文本数据,每个类别数据有 6500 6500 6500 条文本,切分为训练集( 5000 × 10 5000\times10 5000×10 )、验证集( 500 × 10 500\times10 500×10 )和测试集( 1000 × 10 1000\times10 1000×10 )

数据集下载链接:http://thuctc.thunlp.org/

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
import re
import string
import copy
import time
import os
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib
import csv

import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
import torch.utils.data as Data 
import jieba
from torchtext import data
from torchtext.vocab import Vectors
# 输出图显示中文
from matplotlib.font_manager import FontProperties
fonts = FontProperties(fname = 'C:/windows/Fonts/STXIHEI.TTF')
# 模型加载选择GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
cuda
1
GeForce MX250
将文本整合到 train、test、val 三个文件中

数据集划分程序参考:https://github.com/gaussic/text-classification-cnn-rnn

def _read_file(filename):
    """读取一个文件并转换为一行"""
    with open(filename, 'r', encoding='utf-8') as f:
        return f.read().replace('\n', '').replace('\t', '').replace('\u3000', '')

def save_file(dirname):
    """
    将多个文件整合并存到3个文件中
    """
    f_train = open('data/cnews1/cnews.train.txt', 'w', encoding='utf-8')
    f_test = open('data/cnews1/cnews.test.txt', 'w', encoding='utf-8')
    f_val = open('data/cnews1/cnews.val.txt', 'w', encoding='utf-8')
    for category in os.listdir(dirname):   # 分类目录
        cat_dir = os.path.join(dirname, category)
        if not os.path.isdir(cat_dir):
            continue
        files = os.listdir(cat_dir)
        count = 0
        for cur_file in files:
            filename = os.path.join(cat_dir, cur_file)
            content = _read_file(filename)
            if count < 5000:
                f_train.write(category + '\t' + content + '\n')
            elif count < 6000:
                f_test.write(category + '\t' + content + '\n')
            else:
                f_val.write(category + '\t' + content + '\n')
            count += 1

        print('Finished:', category)

    f_train.close()
    f_test.close()
    f_val.close()
save_file('data/thucnews')
print(len(open('data/cnews/cnews.train.txt', 'r', encoding='utf-8').readlines()))
print(len(open('data/cnews/cnews.test.txt', 'r', encoding='utf-8').readlines()))
print(len(open('data/cnews/cnews.val.txt', 'r', encoding
  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值