import torch
import torchtext
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchtext.vocab import GloVe
import time
# Record script start time so total runtime can be reported later.
start=time.time()
# Tokenize reviews into fixed-length sequences: lowercase everything and
# pad/truncate each review to 200 tokens (original comment: "extract 200
# words per article"). batch_first=False -> batches come out (seq_len, batch).
# NOTE(review): this is the legacy torchtext.data API, removed in
# torchtext >= 0.12 — requires an old torchtext version to run.
TEXT = torchtext.data.Field(lower=True, fix_length=200, batch_first=False)
# Labels are single categorical values, not token sequences.
LABEL = torchtext.data.Field(sequential=False)
# Download (if needed) and split the IMDB sentiment dataset.
train, test = torchtext.datasets.IMDB.splits(TEXT, LABEL)
# Build the token vocabulary from the training split only: cap at 10k types,
# drop tokens seen fewer than 10 times. vectors=None -> no pretrained vectors;
# embeddings are trained from scratch (the GloVe import above is unused here).
TEXT.build_vocab(train, max_size=10000, min_freq=10, vectors=None)
LABEL.build_vocab(train)
BATCHSIZE = 256
# Bucket examples of similar length together to minimize padding per batch.
train_iter, test_iter = torchtext.data.BucketIterator.splits((train, test), batch_size=BATCHSIZE)
# Model hyperparameters: embedding width and LSTM hidden state size.
embeding_dim = 100
hidden_size = 300
class Net(nn.Module):
    """LSTM text classifier: embedding -> single-layer LSTM -> 2-layer MLP head.

    Reads the module-level globals ``TEXT`` (for vocabulary size),
    ``embeding_dim`` and ``hidden_size`` defined earlier in this script.

    NOTE(review): the pasted source had the class body flattened to column 0
    (invalid Python); the only change here is restoring the indentation and
    adding documentation — layers and arguments are unchanged.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Vocabulary-sized embedding table, trained from scratch
        # (build_vocab above was called with vectors=None).
        self.em = nn.Embedding(len(TEXT.vocab.stoi), embeding_dim)
        # Single-layer LSTM; expects input shaped (seq_len, batch, embeding_dim),
        # matching the Fields' batch_first=False.
        self.lstm = nn.LSTM(embeding_dim, hidden_size)
        self.fc1 = nn.Linear(hidden_size, 256)
        # 3 output classes — presumably because the legacy torchtext LABEL
        # vocab includes '<unk>' alongside 'pos'/'neg'; TODO confirm.
        self.fc2 = nn.Linear(256, 3)
def for
Simple LSTM-based text classification with PyTorch
Latest recommended article published 2024-07-25 14:52:49
![](https://img-home.csdnimg.cn/images/20240711042549.png)