使用LSTM模型进行多分类并查看综合分类效果

#!/usr/bin/env python
# encoding: utf-8
'''
@author: taoshouzheng
@contact: tsz1216@sina.com
@file: 1 lstm + linear.py
@time: 2019/11/7 9:15
'''

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
from torch.autograd import Variable
import math
import random
import numpy as np
from torch import optim
from sklearn.preprocessing import label_binarize
from sklearn.metrics import classification_report


class MyNet(nn.Module):
    """Single-layer LSTM followed by a linear classification head.

    The LSTM is built without ``batch_first``, so inputs use PyTorch's
    default ``(seq_len, batch, input_size)`` layout. Only the hidden state
    after the last time step is passed to the linear layer, so each
    sequence produces one vector of ``output_size`` class scores.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(MyNet, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.lstm = nn.LSTM(self.input_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def init_state(self, batch_size, hidden_size=None):
        """Return a random initial ``(h, c)`` pair for the LSTM.

        Args:
            batch_size: number of sequences in the batch.
            hidden_size: hidden width; defaults to ``self.hidden_size``.
                It must equal the LSTM's hidden size for ``forward`` to work.

        Returns:
            Two tensors of shape ``(1, batch_size, hidden_size)`` filled
            with uniform random values in [0, 1).
        """
        if hidden_size is None:
            hidden_size = self.hidden_size
        # Plain tensors: the Variable wrapper has been a no-op since PyTorch 0.4.
        h_init = torch.rand(1, batch_size, hidden_size)
        c_init = torch.rand(1, batch_size, hidden_size)
        return h_init, c_init

    def forward(self, x, h, c):
        """Return class scores of shape ``(1, batch, output_size)``.

        Args:
            x: input sequence, ``(seq_len, batch, input_size)``.
            h, c: initial hidden / cell state, each ``(1, batch, hidden_size)``.
        """
        _, (new_h, _) = self.lstm(x, (h, c))
        # new_h is the hidden state after the final time step: (1, batch, hidden).
        return self.linear(new_h)

    def prediction(self, x, h, c):
        """Return the predicted class index for each sequence.

        Args:
            x, h, c: same as for ``forward``.

        Returns:
            A ``torch.long`` tensor of shape ``(batch,)``.
        """
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            scores = self.forward(x, h, c)
        # Drop only the leading num_layers axis. Squeezing a second time would
        # also drop the batch axis when batch == 1 and break the argmax below.
        scores = scores.squeeze(0)
        return torch.max(scores, 1)[1]


def reset_weights(model, value=0.5):
    """Overwrite every parameter of ``model`` in place with a constant.

    Args:
        model: an ``nn.Module``; all tensors from ``model.parameters()``
            (weights and biases) are filled.
        value: the fill value. Defaults to 0.5, the value this script
            originally hard-coded.

    Note:
        A constant fill gives all units in a layer identical incoming
        weights. Unless something else breaks the symmetry, such units
        receive identical updates. Use this for reproducible demos, not as
        a general-purpose initializer.
    """
    for weight in model.parameters():
        init.constant_(weight, value)


net = MyNet(10, 20, 3)
reset_weights(net)

epoch = 100

# Input: 5 time steps x 200 sequences x 10 features, in nn.LSTM's default
# (seq_len, batch, input_size) layout. Every element is 1.
input = torch.ones(5, 200, 10)
print('input的形状')
print(input.shape)

# Labels: one class in {0, 1, 2} per sequence, drawn at random and
# independent of the input. This makes the script a pipeline smoke test
# rather than a meaningful classification task.
label = [random.choice([0, 1, 2]) for _ in range(200)]
# np.int was removed in NumPy 1.24; int64 matches torch.long.
label = np.array(label, dtype=np.int64)

target = torch.as_tensor(label, dtype=torch.long)
print('target的形状')
print(target.shape)

# Random initial (h, c), each of shape (1, 200, 20). The same state is reused
# for every training step and for evaluation.
h_init, c_init = net.init_state(200, 20)

criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Full-batch training: one forward/backward pass over all 200 sequences per epoch.
for i in range(epoch):
    # backward() accumulates into .grad. Without clearing it here, every step
    # would apply the running sum of all previous gradients.
    optimizer.zero_grad()
    output = net(input, h_init, c_init)
    # (1, batch, classes) -> (batch, classes), the shape CrossEntropyLoss expects.
    output = output.squeeze(0)

    loss = criterion(output, target)
    # .item() prints a plain float instead of a tensor carrying grad_fn.
    print('epoch', i + 1, ':', loss.item())

    loss.backward()
    optimizer.step()

net.eval()

# Evaluation input: the same all-ones (seq_len=5, batch=200, features=10)
# tensor used in training.
input = torch.ones(5, 200, 10)
# Labels: drawn afresh at random, unrelated to the input and to the training
# labels. Expect roughly chance-level scores in the report below.
label = [random.choice([0, 1, 2]) for _ in range(200)]
# np.int was removed in NumPy 1.24.
y_test = np.array(label, dtype=np.int64)
# Reuses the training (h, c) initial state; its batch size (200) matches.
with torch.no_grad():
    y_pred = net.prediction(input, h_init, c_init)
print(type(y_pred))
print(y_pred.shape)

# Hand sklearn a NumPy array rather than a torch tensor.
# digits = decimal places in the report; "support" = number of true samples per label.
ans = classification_report(y_test, y_pred.cpu().numpy(), digits=5)
print(ans)

 

©️2020 CSDN 皮肤主题: 精致技术 设计师:CSDN官方博客 返回首页