# 参考代码 (Reference code): https://github.com/L1aoXingyu/code-of-learn-deep-learning-with-pytorch
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
def plot_decision_boundary(model, x, y, h=0.01, padding=1.0):
    """Plot the decision regions of ``model`` with the sample points on top.

    A dense grid is built over the data's bounding box, enlarged by
    ``padding`` on every side. ``model`` is evaluated on every grid point and
    the predicted classes are drawn as filled contours. The samples ``x`` are
    then scattered on top, coloured by ``y``.

    Args:
        model: callable taking an (n_points, 2) NumPy array of grid points and
            returning one prediction per point (anything that NumPy can turn
            into an array and reshape to the grid shape).
        x: NumPy array of points; columns 0 and 1 are used as the two axes.
        y: NumPy array of labels for ``x``; flattened before plotting.
        h: grid step on both axes (default 0.01, the previously hard-coded
            value).
        padding: margin added around the data range on each axis (default
            1.0, the previously hard-coded value).
    """
    # Bounding box of the data, enlarged so no point sits on the plot edge.
    x_min, x_max = x[:, 0].min() - padding, x[:, 0].max() + padding
    y_min, y_max = x[:, 1].min() - padding, x[:, 1].max() + padding
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Predict a class for every grid point: ravel flattens each coordinate
    # array and np.c_ stacks them column-wise into an (n_points, 2) array.
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = np.asarray(Z).reshape(xx.shape)
    # Filled contours of the predicted classes; the Spectral colour map gives
    # each class its own colour (same map as the scatter below).
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(x[:, 0], x[:, 1], c=y.reshape(-1), s=40, cmap=plt.cm.Spectral)
class module_net(nn.Module):
    """Two-layer perceptron: Linear -> Tanh -> Linear.

    The sub-modules are exposed as ``layer1``, ``layer2`` and ``layer3`` so
    they can be inspected from outside (e.g. ``net.layer1.weight``).

    Args:
        num_input: number of input features.
        num_hidden: width of the hidden layer.
        num_output: number of outputs; these are raw scores (no activation is
            applied after ``layer3``).
    """

    def __init__(self, num_input, num_hidden, num_output):
        super().__init__()
        # Constructed in the same order as before so the parameters are
        # created (and randomly initialised) in the same sequence.
        self.layer1 = nn.Linear(num_input, num_hidden)
        self.layer2 = nn.Tanh()
        self.layer3 = nn.Linear(num_hidden, num_output)

    def forward(self, x):
        """Return ``layer3(tanh(layer1(x)))`` for the input batch ``x``."""
        hidden = self.layer2(self.layer1(x))
        return self.layer3(hidden)
# Build the network: 2 inputs, 4 hidden units, 1 output.
mo_net = module_net(2, 4, 1)
# Inspect the first layer: print the module itself, then its weight matrix.
L1 = mo_net.layer1
print(L1)
print(L1.weight)
# Loss on raw logits: BCEWithLogitsLoss combines the sigmoid and the binary
# cross-entropy in one op.
criterion = nn.BCEWithLogitsLoss()
# Optimizer: plain SGD over every network parameter, learning rate 1.0.
optim = torch.optim.SGD(params=mo_net.parameters(), lr=1.0)
# Generate a reproducible 2-D, two-class dataset. Points lie on a noisy
# rose curve (radius = a * sin(4 * angle)); class 0 takes the first half of
# the angle range and class 1 the second half.
np.random.seed(1)
m = 400          # total number of points
N = m // 2       # points per class
D = 2            # dimensionality of each point
a = 4            # amplitude of the radius
x = np.zeros((m, D))
y = np.zeros((m, 1), dtype='uint8')
for label in range(2):
    rows = slice(N * label, N * (label + 1))
    # Noisy angles over this class's half of the range, then a noisy radius
    # that oscillates with the angle.
    theta = np.linspace(label * 3.12, (label + 1) * 3.12, N) + np.random.randn(N) * 0.2
    radius = a * np.sin(4 * theta) + np.random.randn(N) * 0.2
    x[rows] = np.c_[radius * np.sin(theta), radius * np.cos(theta)]
    y[rows] = label
# Convert both arrays to float32 tensors for training.
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()
# Train for 10000 full-batch iterations, printing the loss every 1000 steps.
for e in range(10000):
    # Plain tensors carry autograd state directly; the deprecated Variable
    # wrapper is an identity wrapper and is no longer needed.
    out = mo_net(x)
    loss = criterion(out, y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if (e + 1) % 1000 == 0:
        print('epoch:{},loss:{}'.format(e + 1, loss.item()))
import torch.nn.functional as F
def plot_net(x, model=None):
    """Predict hard 0/1 labels for the points in ``x``.

    Args:
        x: NumPy array of input points; converted to a float32 tensor.
        model: network that returns one logit per point. When omitted, the
            module-level ``mo_net`` is used, as before.

    Returns:
        NumPy integer array shaped like the model output: 1 where
        sigmoid(logit) > 0.5, otherwise 0.
    """
    net = mo_net if model is None else model
    # Inference only: no_grad skips building an autograd graph over the
    # (potentially very large) grid of points passed in for plotting.
    with torch.no_grad():
        logits = net(torch.from_numpy(x).float())
    # torch.sigmoid replaces the deprecated F.sigmoid.
    probs = torch.sigmoid(logits).numpy()
    return (probs > 0.5) * 1
# Draw the trained network's decision regions over the training data.
# plot_net already maps an array of points to labels, so it is passed
# directly rather than through a wrapping lambda.
plot_decision_boundary(plot_net, x.numpy(), y.numpy())
# The plotted model is the nn.Module subclass (module_net), not an
# nn.Sequential, so the title says 'module'.
plt.title('module')
plt.show()