#!/usr/bin/env python
# coding: utf-8
# In[11]:
# Notebook-exported script: enable inline matplotlib rendering in Jupyter.
# NOTE(review): get_ipython() only exists inside an IPython/Jupyter session;
# running this file with plain `python` will raise NameError here.
get_ipython().run_line_magic('matplotlib', 'inline')
import torch
import random
import numpy as np
from IPython import display
from matplotlib import pyplot as plt
# In[12]:
# Dataset for the XOR-style problem: each row is (x1, x2, label),
# label -1 when exactly one input is 1, label +1 otherwise.
raw = np.array(
    [[1, 0, -1],
     [0, 1, -1],
     [1, 1, 1],
     [0, 0, 1]],
    dtype='float32',
)
data = torch.from_numpy(raw)
# Split into the two input columns and the label column.
features, labels = data[:, :2], data[:, 2]
print(features, labels)
# In[ ]:
# In[13]:
# Data loading — the concise torch.utils.data approach.
import torch.utils.data as Data

# Mini-batch size for SGD.
batch_size = 2
# Pair each feature row with its label, then let a DataLoader
# serve shuffled mini-batches.
dataset = Data.TensorDataset(features, labels)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)

# Peek at one random mini-batch, then stop.
for X, y in data_iter:
    print(X, y)
    break
# In[14]:
# Model definition — the concise Sequential approach:
# 2 inputs -> 2-unit hidden layer -> ReLU -> 1 output.
import torch.nn as nn

layers = [
    nn.Linear(2, 2),
    nn.ReLU(),
    nn.Linear(2, 1),
]
net = nn.Sequential(*layers)
print(net)
# In[15]:
# Initialise both Linear layers (indices 0 and 2; index 1 is the
# parameter-free ReLU): small Gaussian weights, zero biases.
# The loop visits weight-then-bias per layer, so the RNG is consumed
# in the same order as initialising each tensor individually.
for layer_idx in (0, 2):
    nn.init.normal_(net[layer_idx].weight, mean=0, std=0.01)
    nn.init.constant_(net[layer_idx].bias, val=0)

# Squared-error loss for the ±1 regression targets.
loss = nn.MSELoss()

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.03)
print(optimizer)

# Train with mini-batch SGD; reshape labels to a column so they
# broadcast against the (batch, 1) network output.
num_epochs = 1000
for epoch in range(1, num_epochs + 1):
    for X, y in data_iter:
        batch_loss = loss(net(X), y.view(-1, 1))
        optimizer.zero_grad()  # clear gradients; equivalent to net.zero_grad()
        batch_loss.backward()
        optimizer.step()
    print('epoch %d, loss: %f' % (epoch, batch_loss.item()))
# Inspect the learned parameters of both Linear layers
# (indices 0 and 2; the ReLU at index 1 has none).
for i in (0, 2):
    dense = net[i]
    print(i, '的weight', dense.weight)
    print(i, '的bias', dense.bias)

# Run the trained network over the full feature set.
test = net(features)
print('最后结果是:', test)
# In[16]:
# Manual verification: replay the forward pass by hand using one set of
# weights copied from a previous training run.
W1 = torch.tensor([[-1.6941, 1.6941],
                   [-1.0862, 1.0863]])
print(W1, W1.size())
W2 = torch.tensor([[-2.3606, 1.8405]])
print(W2, W2.size())
B1 = torch.tensor([-1.3022e-05, 1.0862e+00])
B2 = torch.tensor([-0.9993])

for i in range(0, 4):
    # Hidden layer: affine transform of one (2,1) input column ...
    hidden = torch.mm(W1, features[i, :].view(2, 1)) + B1.view(2, 1)
    # ... followed by ReLU (zero out negative activations).
    hidden = torch.clamp(hidden, min=0)
    # Output layer: second affine transform.
    ans = torch.mm(W2, hidden) + B2
    print(i, 'ans:', ans)
# In[21]:
# Show the notebook's current working directory.
import os

cwd = os.path.abspath('.')
print(cwd)
# This solves the XOR problem. With only 4 data points the sample is tiny,
# so convergence quality differs between runs and fluctuates quite a bit;
# just rerun it a few times if a run comes out poorly.