torch 训练模板

PyTorch 文本分类(BiLSTM)训练脚本模板:包含训练/评估循环、按 epoch 保存模型以及断点续训。

 

from text_semantic.model import TextBiLSTM
from text_semantic.dataset import MyDataset
from text_semantic.config import TEMP_PATH, RECORD_PATH

import time
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader


def train_eval(cate, data_loader, model, optimizer, loss_func, device=None):
    """Run one full pass over ``data_loader``, training or evaluating.

    Args:
        cate: ``'train'`` to backprop and step the optimizer; any other
            value runs a pure evaluation pass.
        data_loader: DataLoader yielding ``(inputs, targets)`` batches.
        model: network to run; switched to ``train()``/``eval()`` mode.
        optimizer: stepped only when ``cate == 'train'``.
        loss_func: criterion mapping ``(output, target)`` to a scalar loss.
        device: device to move batches to. Defaults to CUDA, preserving
            the original hard-coded ``.cuda()`` behavior.

    Returns:
        Tuple ``(accuracy_percent, mean_batch_loss)`` as Python floats.
    """
    if device is None:
        device = torch.device('cuda')
    is_train = cate == 'train'
    model.train() if is_train else model.eval()
    correct, loss_sum = 0, 0.0
    # Disable autograd during evaluation: the original built a full
    # computation graph per eval batch for nothing.
    with torch.set_grad_enabled(is_train):
        for x, target in data_loader:
            x, target = x.to(device), target.to(device)
            y = model(x)
            loss = loss_func(y, target)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # .item() (not deprecated .data) converts to Python numbers,
            # so no tensors are kept alive across batches and the returned
            # accuracy is a real float rather than a 0-dim tensor.
            correct += y.max(dim=1)[1].eq(target).sum().item()
            loss_sum += loss.item()
    acc = correct * 100 / len(data_loader.dataset)
    return acc, loss_sum / len(data_loader)


if __name__ == '__main__':
    # ----- model hyper-parameters -----
    num_words = 35131
    num_classes = 20
    embedding_dim = 300
    hidden_size = 100
    dropout = 0.2
    word2vec = np.load(TEMP_PATH + '/done-word2vec.npy')
    padding_len = 400

    # ----- training hyper-parameters -----
    start = 45  # checkpoint epoch to resume from; 0 means train from scratch
    batch_size = 512
    lr = 1e-5

    print("init & load...")
    train_data = DataLoader(
        MyDataset('train', padding_len, num_words),
        batch_size=batch_size,
        shuffle=True,
    )
    test_data = DataLoader(
        MyDataset('test', padding_len, num_words),
        batch_size=batch_size,
    )

    model = TextBiLSTM(num_words, num_classes, embedding_dim, hidden_size, word2vec, dropout)
    if start != 0:
        # Resume: load the weights saved at epoch `start`.
        model.load_state_dict(torch.load(RECORD_PATH + '/model.{}.pth'.format(start)))
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("start...")
    model = model.cuda()
    for epoch in range(start + 1, 200):
        epoch_start = time.time()
        train_acc, train_loss = train_eval('train', train_data, model, optimizer, loss_func)
        test_acc, test_loss = train_eval('test', test_data, model, optimizer, loss_func)
        elapsed = time.time() - epoch_start
        # Checkpoint every epoch so any epoch can later serve as `start`.
        torch.save(model.state_dict(), RECORD_PATH + '/model.{}.pth'.format(epoch))
        print("epoch=%s, cost=%.2f, train:[loss=%.4f, acc=%.2f%%], test:[loss=%.4f, acc=%.2f%%]"
              % (epoch, elapsed, train_loss, train_acc, test_loss, test_acc))

 

 

 

 

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值