# PyTorch是很受欢迎的机器学习库，对应的库名为torch。
'''使用torch识别简单正态曲线'''
import torch
import numpy as np
from scipy.stats import norm
def create_norm_zero(x_num):
    """Sample a scaled standard-normal PDF curve.

    The PDF is evaluated at 100 evenly spaced points of ``[-6, 6]`` multiplied
    by ``x_num``, divided by the largest PDF value found on 100 evenly spaced
    points of ``[-3, 3]``, and finally multiplied by 20.

    Args:
        x_num: Factor applied to the sampling grid; larger values stretch the
            grid and therefore narrow the sampled bell curve.

    Returns:
        A one-element list containing the NumPy array of 100 curve samples.
    """
    sample_points = np.linspace(-6, 6, 100) * x_num
    reference_points = np.linspace(-3, 3, 100)
    # Normalise by the peak of the reference grid so the curve tops out near 1
    # before the final scaling.
    reference_peak = np.max(norm.pdf(reference_points))
    normalised = norm.pdf(sample_points) / reference_peak
    return [normalised * 20]
def create_norm_data():
    """Build the array of generated normal curves.

    A single curve is produced by ``create_norm_zero`` with the grid factor
    ``(0 + 1) * 0.23 + 1``; the result is wrapped in two extra list levels
    before conversion, so the returned array has four dimensions.

    Returns:
        A NumPy array wrapping the generated curve(s).
    """
    curves = [create_norm_zero((idx + 1) * 0.23 + 1) for idx in range(1)]
    # The extra outer list adds the leading dimension of the final array.
    return np.array([curves])
# Constant all-ones channel that is paired with the generated curve.
a = torch.ones(1, 1, 1, 100)
b = create_norm_data()
c = torch.from_numpy(b)
# Stack the curve and the constant channel along dim 1, then cast to float32.
x = torch.cat((c, a), dim=1).float()
# Labels: the first 50 positions are class 1, the last 50 are class 0.
y_t = torch.ones(1, 1, 50)
y_f = torch.zeros(1, 1, 50)
y = torch.cat((y_t, y_f), dim=-1).long()
# 上段为生成训练数据部分，下段为定义模型、训练模型部分。
# Model: a 1x1 convolution mapping the 2 input channels to 2 class scores
# for every position of the input.
net = torch.nn.Conv2d(2, 2, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss()

# Train for 100 plain SGD steps on the single sample (x, y).
for t in range(100):
    out = net(x)
    # Index of the larger score along dim 1 is the predicted class.
    prediction = torch.max(out, 1)[1]
    print(prediction)
    loss = loss_func(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Prediction of the last training step as a NumPy array.
pred_y = prediction.detach().numpy()

# Persist the whole module object (not just a state_dict) and read it back.
torch.save(net, './model.pth')
# Since PyTorch 2.6 ``torch.load`` defaults to ``weights_only=True`` and
# refuses to unpickle a full module. The file was written one line above by
# this script, so it is trusted and full unpickling is enabled explicitly.
# NOTE: never use ``weights_only=False`` on files from untrusted sources.
model = torch.load('./model.pth', weights_only=False)
# 下段为测试模型预测结果。
# Evaluation input: the generated curve used for both input channels.
test = torch.cat((c, c), dim=1).float()
out = model(test)
# Predicted class per position = index of the larger score along dim 1.
_, prediction = torch.max(out, dim=1)
print(prediction)
# Reference labels: 44 zeros, 12 ones (indices 44..55), 44 zeros.
compare = torch.tensor([[[0] * 44 + [1] * 12 + [0] * 44]])
# Element-wise agreement between the prediction and the reference labels.
result = torch.eq(prediction, compare)
print(result)