# 使用LSTM模型进行多分类并查看综合分类效果

#!/usr/bin/env python
# encoding: utf-8
'''
@author: taoshouzheng
@contact: tsz1216@sina.com
@file: 1 lstm + linear.py
@time: 2019/11/7 9:15
'''

import math
import random

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import classification_report
from sklearn.preprocessing import label_binarize
from torch import Tensor
from torch import optim
from torch.autograd import Variable
from torch.nn import Parameter
from torch.nn import init

class MyNet(nn.Module):
    """Single-layer LSTM followed by a linear classification head.

    The linear layer is applied to the LSTM's hidden state after the last time
    step, so the model yields one row of class scores per sequence in the batch.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """
        Args:
            input_size: number of features per time step.
            hidden_size: number of LSTM hidden units.
            output_size: number of output classes.
        """
        super(MyNet, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # batch_first is left at its default (False), so the LSTM expects
        # inputs laid out as (seq_len, batch, input_size).
        self.lstm = nn.LSTM(self.input_size, self.hidden_size)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def init_state(self, batch_size, hidden_size):
        """Return initial (h, c) states, each of shape (1, batch_size, hidden_size).

        The leading 1 is num_layers * num_directions for this single-layer,
        unidirectional LSTM. Values are drawn uniformly from [0, 1).
        """
        h_init = torch.rand(1, batch_size, hidden_size)
        c_init = torch.rand(1, batch_size, hidden_size)
        return h_init, c_init

    def forward(self, x, h, c):
        """Return class scores of shape (1, batch, output_size).

        Args:
            x: input of shape (seq_len, batch, input_size).
            h: initial hidden state, shape (1, batch, hidden_size).
            c: initial cell state, shape (1, batch, hidden_size).
        """
        _, (new_h, _) = self.lstm(x, (h, c))
        # Classify from the hidden state after the final time step.
        return self.linear(new_h)

    def prediction(self, x, h, c):
        """Return the predicted class index for each sequence, shape (batch,).

        Runs without gradient tracking; arguments are as for ``forward``.
        """
        with torch.no_grad():
            scores = self(x, h, c)
        # Remove only the leading num_layers axis. Squeezing dim 0 twice would
        # also drop the batch axis when batch == 1 and break the max below.
        scores = scores.squeeze(0)
        return torch.max(scores, 1)[1]

def reset_weights(model, value=0.5):
    """Fill every parameter of ``model`` in place with the constant ``value``.

    This gives a deterministic starting point. Be aware that a constant
    initialization makes all hidden units start out identical, so they receive
    identical gradients and never differentiate. Use it for reproducible
    demos or debugging, not for real training.

    Args:
        model: any ``nn.Module``.
        value: fill value. Defaults to 0.5, the previously hard-coded constant.
    """
    for weight in model.parameters():
        init.constant_(weight, value)

# Model: 10 input features, 20 hidden units, 3 classes, with a deterministic
# constant initialization.
net = MyNet(10, 20, 3)
reset_weights(net)

epoch = 100  # number of training iterations

# Input: constant tensor laid out as (seq_len=5, batch=200, input_size=10).
inputs = torch.ones(5, 200, 10)
print('input的形状')
print(inputs.shape)

# Labels: 200 random class ids drawn from {0, 1, 2}.
label = np.array([random.choice([0, 1, 2]) for _ in range(200)], dtype=np.int64)

# CrossEntropyLoss expects integer (long) class indices.
target = torch.as_tensor(label, dtype=torch.long)
print('target的形状')
print(target.shape)

# Fixed initial LSTM state, reused for every iteration (and for evaluation).
h_init, c_init = net.init_state(200, 20)

criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for i in range(epoch):
    # Clear gradients from the previous step; backward() accumulates into .grad.
    optimizer.zero_grad()

    # (1, batch, classes) -> (batch, classes) for the loss.
    output = net(inputs, h_init, c_init).squeeze(0)

    loss = criterion(output, target)
    print('epoch', i + 1, ':', loss.item())

    loss.backward()
    optimizer.step()

net.eval()

# Evaluation input with the same (seq_len=5, batch=200, input_size=10) layout
# as training. The batch must stay 200 because h_init/c_init are reused below.
test_input = torch.ones(5, 200, 10)

# Labels: 200 random class ids drawn from {0, 1, 2}. This is synthetic demo
# data, so the report below is not expected to show meaningful accuracy.
y_test = np.array([random.choice([0, 1, 2]) for _ in range(200)], dtype=np.int64)

with torch.no_grad():
    y_pred = net.prediction(test_input, h_init, c_init)
print(type(y_pred))
print(y_pred.shape)

# digits: number of decimal places in the report. The "support" column is the
# number of true samples for each label.
ans = classification_report(y_test, y_pred.cpu().numpy(), digits=5)
print(ans)

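# NOTE: the lines below are residue from the web page this script was copied
# from (post dates / view counts). They are not Python and are kept only as a comment.
# 09-18
#
# 07-18
# 04-12
# 12-03
# 05-06 1万+
# 05-17 2万+
# 12-06
# 03-11 1万+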