As the title says: the complete training script.
import time

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from text_semantic.config import TEMP_PATH, RECORD_PATH
from text_semantic.dataset import MyDataset
from text_semantic.model import TextBiLSTM

def train_eval(cate, data_loader, model, optimizer, loss_func):
    """Run one epoch of training or evaluation; return (accuracy %, mean loss)."""
    if cate == 'train':
        model.train()
    else:
        model.eval()
    acc, loss_sum = 0.0, 0.0
    # Only track gradients during training; evaluation runs without autograd.
    with torch.set_grad_enabled(cate == 'train'):
        for x, target in data_loader:
            x, target = x.cuda(), target.cuda()
            y = model(x)
            loss = loss_func(y, target)
            if cate == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # .item() extracts plain Python scalars (the old .data idiom is deprecated).
            acc += y.max(dim=1)[1].eq(target).sum().item()
            loss_sum += loss.item()
    acc = acc * 100 / len(data_loader.dataset)
    loss_sum = loss_sum / len(data_loader)
    return acc, loss_sum
if __name__ == '__main__':
    # Hyper-parameters and paths
    num_words = 35131
    num_classes = 20
    embedding_dim = 300
    hidden_size = 100
    dropout = 0.2
    word2vec = np.load(TEMP_PATH + '/done-word2vec.npy')  # pre-trained embedding matrix
    padding_len = 400
    start = 45  # resume from this epoch's checkpoint; set to 0 to train from scratch
    batch_size = 512
    lr = 1e-5

    print("init & load...")
    train_data = DataLoader(MyDataset('train', padding_len, num_words), batch_size=batch_size, shuffle=True)
    test_data = DataLoader(MyDataset('test', padding_len, num_words), batch_size=batch_size)
    model = TextBiLSTM(num_words, num_classes, embedding_dim, hidden_size, word2vec, dropout)
    if start != 0:
        model.load_state_dict(torch.load(RECORD_PATH + '/model.{}.pth'.format(start)))
    # Move the model to the GPU before building the optimizer, so the optimizer
    # state is created on the same device as the parameters.
    model = model.cuda()
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("start...")
    for epoch in range(start + 1, 200):
        t1 = time.time()
        train_acc, train_loss = train_eval('train', train_data, model, optimizer, loss_func)
        test_acc, test_loss = train_eval('test', test_data, model, optimizer, loss_func)
        cost = time.time() - t1
        # Checkpoint after every epoch so training can be resumed via `start`.
        torch.save(model.state_dict(), RECORD_PATH + '/model.{}.pth'.format(epoch))
        print("epoch=%s, cost=%.2f, train:[loss=%.4f, acc=%.2f%%], test:[loss=%.4f, acc=%.2f%%]"
              % (epoch, cost, train_loss, train_acc, test_loss, test_acc))
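
The text_semantic.model module is not shown above. For reference, here is a minimal sketch of what TextBiLSTM could look like, inferred only from the constructor call and the shapes it implies; the layer layout, the last-time-step pooling, and the embedding initialization are assumptions, not the actual implementation.

import torch
from torch import nn


class TextBiLSTM(nn.Module):
    """Hypothetical BiLSTM classifier consistent with the call site above."""

    def __init__(self, num_words, num_classes, embedding_dim, hidden_size,
                 word2vec, dropout):
        super().__init__()
        # Embedding table initialized from the pre-trained word2vec matrix
        # (assumed shape: num_words x embedding_dim).
        self.embedding = nn.Embedding(num_words, embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(word2vec))
        self.lstm = nn.LSTM(embedding_dim, hidden_size,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # A bidirectional LSTM outputs 2 * hidden_size features per step.
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, padding_len) word indices.
        emb = self.dropout(self.embedding(x))
        out, _ = self.lstm(emb)  # (batch, seq_len, 2 * hidden_size)
        # Use the last time step as the sequence representation.
        return self.fc(self.dropout(out[:, -1, :]))  # raw logits for CrossEntropyLoss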
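
Likewise, MyDataset is only visible through its MyDataset(cate, padding_len, num_words) call and the (x, target) batches that train_eval unpacks. Below is a hypothetical sketch under those assumptions; the file names and preprocessing are placeholders, not the real dataset code.

import numpy as np
import torch
from torch.utils.data import Dataset

from text_semantic.config import TEMP_PATH


class MyDataset(Dataset):
    """Hypothetical dataset yielding (index sequence, label) pairs."""

    def __init__(self, cate, padding_len, num_words):
        # Placeholder file names; assumed to hold pre-tokenized index
        # sequences and integer class labels from an earlier preprocessing step.
        raw = np.load(TEMP_PATH + '/{}-seqs.npy'.format(cate), allow_pickle=True)
        self.labels = np.load(TEMP_PATH + '/{}-labels.npy'.format(cate))
        # Pad/truncate every sequence to padding_len and keep indices in range.
        self.seqs = np.zeros((len(raw), padding_len), dtype=np.int64)
        for i, seq in enumerate(raw):
            seq = np.clip(np.asarray(seq[:padding_len]), 0, num_words - 1)
            self.seqs[i, :len(seq)] = seq

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # (x, target) as consumed by train_eval: LongTensor indices, int label.
        return torch.from_numpy(self.seqs[idx]), int(self.labels[idx])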