PyTorch - Single-Layer Neural Network

1. The difference between supervised and unsupervised learning

https://www.zhihu.com/question/27138263/answer/635004780
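In short: supervised learning fits a mapping from inputs to known labels or targets (as in the regression and classification examples below), while unsupervised learning looks for structure in unlabeled data, e.g. clustering.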

2. Implementing a simple y = wx + b model

(The code below keeps the original torch.autograd.Variable wrapper; since PyTorch 0.4, Variable is deprecated and a plain tensor created with requires_grad=True behaves the same way.)

from torch.autograd import Variable
import torch
x_train = Variable(torch.linspace(1, 10, 10), requires_grad=True)
y_train = Variable(torch.linspace(10, 1, 10), requires_grad=True)
x_train, y_train
(tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.], requires_grad=True),
 tensor([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.], requires_grad=True))
w = Variable(torch.randn(1), requires_grad=True)
b = Variable(torch.zeros(1), requires_grad=True)
def linear_model(x):
    return w * x + b
def get_loss(y_, y):
    # mean squared error between predictions and targets
    return torch.mean((y_ - y) ** 2)
for i in range(1000):
    y_ = linear_model(x_train)
    loss = get_loss(y_, y_train)
    
    # zero the gradients of w and b (skipped on the first iteration, when .grad is still None)
    if i != 0:
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    loss.backward()
    # update w by gradient descent
    w.data = w.data - 1e-2 * w.grad.data
    # update b by gradient descent
    b.data = b.data - 1e-2 * b.grad.data
    
    if i % 200 == 0:
        print('epoch: {}, loss: {}'.format(i, loss.data))
epoch: 0, loss: 43.67679977416992
epoch: 200, loss: 4.732393264770508
epoch: 400, loss: 0.8790384531021118
epoch: 600, loss: 0.16328111290931702
epoch: 800, loss: 0.030329588800668716
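For comparison, the same model can be written with the higher-level torch.nn and torch.optim APIs instead of hand-rolled updates. A minimal sketch (my addition, not from the original post; the names model, criterion and optimizer are arbitrary):

import torch
from torch import nn

x = torch.linspace(1, 10, 10).unsqueeze(1)  # shape (10, 1): one feature per sample
y = torch.linspace(10, 1, 10).unsqueeze(1)

model = nn.Linear(1, 1)  # learns w and b internally
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for i in range(1000):
    loss = criterion(model(x), y)  # forward pass + MSE loss
    optimizer.zero_grad()          # clear the old gradients
    loss.backward()                # backpropagate
    optimizer.step()               # gradient-descent update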

3. Logistic regression model

Logistic regression outputs a probability between 0 and 1: a linear score w·x + b is passed through the sigmoid function sigma(z) = 1 / (1 + e^(-z)).
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
data = [(34.62365962451697, 78.0246928153624, 0.0),
 (30.28671076822607, 43.89499752400101, 0.0),
 (35.84740876993872, 72.90219802708364, 0.0),
 (60.18259938620976, 86.30855209546826, 1.0),
 (79.0327360507101, 75.3443764369103, 1.0),
 (45.08327747668339, 56.3163717815305, 0.0),
 (61.10666453684766, 96.51142588489624, 1.0),
 (75.02474556738889, 46.55401354116538, 1.0),
 (76.09878670226257, 87.42056971926803, 1.0),
 (84.43281996120035, 43.53339331072109, 1.0),
 (95.86155507093572, 38.22527805795094, 0.0),
 (75.01365838958247, 30.60326323428011, 0.0),
 (82.30705337399482, 76.48196330235604, 1.0),
 (69.36458875970939, 97.71869196188608, 1.0),
 (39.53833914367223, 76.03681085115882, 0.0),
 (53.9710521485623, 89.20735013750205, 1.0),
 (69.07014406283025, 52.74046973016765, 1.0),
 (67.94685547711617, 46.67857410673128, 0.0),
 (70.66150955499435, 92.92713789364831, 1.0),
 (76.97878372747498, 47.57596364975532, 1.0),
 (67.37202754570876, 42.83843832029179, 0.0),
 (89.6767757507208, 65.79936592745237, 1.0),
 (50.534788289883, 48.85581152764205, 0.0),
 (34.21206097786789, 44.20952859866288, 0.0),
 (77.9240914545704, 68.9723599933059, 1.0),
 (62.27101367004632, 69.95445795447587, 1.0),
 (80.1901807509566, 44.82162893218353, 1.0),
 (93.114388797442, 38.80067033713209, 0.0),
 (61.83020602312595, 50.25610789244621, 0.0),
 (38.78580379679423, 64.99568095539578, 0.0),
 (61.379289447425, 72.80788731317097, 1.0),
 (85.40451939411645, 57.05198397627122, 1.0),
 (52.10797973193984, 63.12762376881715, 0.0),
 (52.04540476831827, 69.43286012045222, 1.0),
 (40.23689373545111, 71.16774802184875, 0.0),
 (54.63510555424817, 52.21388588061123, 0.0),
 (33.91550010906887, 98.86943574220611, 0.0),
 (64.17698887494485, 80.90806058670817, 1.0),
 (74.78925295941542, 41.57341522824434, 0.0),
 (34.1836400264419, 75.2377203360134, 0.0),
 (83.90239366249155, 56.30804621605327, 1.0),
 (51.54772026906181, 46.85629026349976, 0.0),
 (94.44336776917852, 65.56892160559052, 1.0),
 (82.36875375713919, 40.61825515970618, 0.0),
 (51.04775177128865, 45.82270145776001, 0.0),
 (62.22267576120188, 52.06099194836679, 0.0),
 (77.19303492601364, 70.45820000180959, 1.0),
 (97.77159928000232, 86.7278223300282, 1.0),
 (62.07306379667647, 96.76882412413983, 1.0),
 (91.56497449807442, 88.696292545466, 1.0),
 (79.94481794066932, 74.16311935043758, 1.0),
 (99.2725269292572, 60.99903099844988, 1.0),
 (90.54671411399852, 43.39060180650027, 1.0),
 (34.52451385320009, 60.39634245837173, 0.0),
 (50.2864961189907, 49.80453881323059, 0.0),
 (49.58667721632031, 59.80895099453265, 0.0),
 (97.64563396007767, 68.86157272420604, 1.0),
 (32.57720016809309, 95.59854761387875, 0.0),
 (74.24869136721598, 69.82457122657193, 1.0),
 (71.7964620586338, 78.45356224515052, 1.0),
 (75.3956114656803, 85.75993667331619, 1.0),
 (35.28611281526193, 47.02051394723416, 0.0),
 (56.25381749711624, 39.26147251058019, 0.0),
 (30.05882244669796, 49.59297386723685, 0.0),
 (44.66826172480893, 66.45008614558913, 0.0),
 (66.56089447242954, 41.09209807936973, 0.0),
 (40.45755098375164, 97.53518548909936, 1.0),
 (49.07256321908844, 51.88321182073966, 0.0),
 (80.27957401466998, 92.11606081344084, 1.0),
 (66.74671856944039, 60.99139402740988, 1.0),
 (32.72283304060323, 43.30717306430063, 0.0),
 (64.0393204150601, 78.03168802018232, 1.0),
 (72.34649422579923, 96.22759296761404, 1.0),
 (60.45788573918959, 73.09499809758037, 1.0),
 (58.84095621726802, 75.85844831279042, 1.0),
 (99.82785779692128, 72.36925193383885, 1.0),
 (47.26426910848174, 88.47586499559782, 1.0),
 (50.45815980285988, 75.80985952982456, 1.0),
 (60.45555629271532, 42.50840943572217, 0.0),
 (82.22666157785568, 42.71987853716458, 0.0),
 (88.9138964166533, 69.80378889835472, 1.0),
 (94.83450672430196, 45.69430680250754, 1.0),
 (67.31925746917527, 66.58935317747915, 1.0),
 (57.23870631569862, 59.51428198012956, 1.0),
 (80.36675600171273, 90.96014789746954, 1.0),
 (68.46852178591112, 85.59430710452014, 1.0),
 (42.0754545384731, 78.84478600148043, 0.0),
 (75.47770200533905, 90.42453899753964, 1.0),
 (78.63542434898018, 96.64742716885644, 1.0),
 (52.34800398794107, 60.76950525602592, 0.0),
 (94.09433112516793, 77.15910509073893, 1.0),
 (90.44855097096364, 87.50879176484702, 1.0),
 (55.48216114069585, 35.57070347228866, 0.0),
 (74.49269241843041, 84.84513684930135, 1.0),
 (89.84580670720979, 45.35828361091658, 1.0),
 (83.48916274498238, 48.38028579728175, 1.0),
 (42.2617008099817, 87.10385094025457, 1.0),
 (99.31500880510394, 68.77540947206617, 1.0),
 (55.34001756003703, 64.9319380069486, 1.0),
 (74.77589300092767, 89.52981289513276, 1.0)]
# Normalize: scale each feature by its maximum so both lie in [0, 1]
x0_max = max([i[0] for i in data])
x1_max = max([i[1] for i in data])
data = [(i[0] / x0_max, i[1] / x1_max, i[2]) for i in data]

x0 = list(filter(lambda x: x[-1] == 0.0, data))
x1 = list(filter(lambda x: x[-1] == 1.0, data))

plot_x0 = [i[0] for i in x0]
plot_y0 = [i[1] for i in x0]
plot_x1 = [i[0] for i in x1]
plot_y1 = [i[1] for i in x1]

plt.plot(plot_x0, plot_y0, 'ro', label='x_0')
plt.plot(plot_x1, plot_y1, 'bo', label='x_1')
plt.legend(loc='best')
np_data = np.array(data, dtype='float32')
x_data = torch.from_numpy(np_data[:, 0:2])
y_data = torch.from_numpy(np_data[:, -1]).unsqueeze(1)
# the larger the input, the closer the output is to 1; the more negative, the closer to 0
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
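A quick sanity check of this function (my addition):

# sigmoid(0) = 0.5; large positive inputs saturate toward 1, large negative toward 0
print(sigmoid(0), sigmoid(10), sigmoid(-10))  # 0.5, ~0.99995, ~4.5e-05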
x_data = Variable(x_data)
y_data = Variable(y_data)

w = Variable(torch.randn(2, 1), requires_grad=True)
b = Variable(torch.zeros(1), requires_grad=True)
def logistic_regression(x):
    return torch.sigmoid(torch.mm(x, w) + b)
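Here x_data has shape (100, 2) and w has shape (2, 1), so torch.mm(x, w) + b produces one linear score per sample, which the sigmoid squashes into a probability.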
# decision boundary under the initial (random) parameters
w0 = w[0].data[0]
w1 = w[1].data[0]
b0 = b.data[0]

plot_x = np.arange(0.2, 1, 0.01)
_plot_x = torch.from_numpy(plot_x)

_plot_y = (-w0 * _plot_x - b0) / w1
plot_y = _plot_y.numpy()

plt.plot(plot_x, plot_y, 'g', label='initial boundary')
plt.plot(plot_x0, plot_y0, 'ro', label='x_0')
plt.plot(plot_x1, plot_y1, 'bo', label='x_1')
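The green line is where the model is exactly undecided: setting w0*x + w1*y + b = 0 (predicted probability 0.5) and solving for y gives y = -(w0*x + b) / w1, the formula used above.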
# compute the loss: binary cross-entropy, clamped to avoid log(0)
def binary_loss(y_pred, y):
    logits = (y * y_pred.clamp(1e-12).log() + (1 - y) * (1 - y_pred).clamp(1e-12).log()).mean()
    return -logits
y_pred = logistic_regression(x_data)
loss = binary_loss(y_pred, y_data)
print("first: ", loss)

# backpropagate and take one manual gradient step
loss.backward()
w.data = w.data - 0.1 * w.grad.data
b.data = b.data - 0.1 * b.grad.data

# loss after a single update step
y_pred = logistic_regression(x_data)
loss = binary_loss(y_pred, y_data)
print("second: ", loss)
first:  tensor(0.7332, grad_fn=<NegBackward>)
second:  tensor(0.7268, grad_fn=<NegBackward>)
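This hand-written loss is binary cross-entropy; as a sanity check (my addition), it should agree with PyTorch's built-in version, since the clamp only matters for predictions of exactly 0 or 1:

import torch.nn.functional as F
print(binary_loss(y_pred, y_data))             # hand-written BCE
print(F.binary_cross_entropy(y_pred, y_data))  # built-in BCE, same value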
# use torch.optim to update the parameters
from torch import nn

w = nn.Parameter(torch.randn(2, 1))
b = nn.Parameter(torch.zeros(1))

def logistic_regression(x):
    return torch.sigmoid(torch.mm(x, w) + b)

optimizer = torch.optim.SGD([w, b], lr=1.)
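Note the learning rate of 1.0: that would normally be far too large for SGD, but both features were scaled into [0, 1] above, so the gradients stay small and training remains stable.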
# run 1000 updates
import time

start = time.time()
for e in range(1000):
    # forward pass
    y_pred = logistic_regression(x_data)
    loss = binary_loss(y_pred, y_data)  # compute the loss
    # backward pass
    optimizer.zero_grad()  # zero the gradients via the optimizer
    loss.backward()
    optimizer.step()  # update the parameters with the optimizer
    # compute the accuracy: threshold the predicted probabilities at 0.5
    mask = y_pred.ge(0.5).float()
    acc = (mask == y_data).sum().item() / y_data.shape[0]
    if (e + 1) % 200 == 0:
        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))
during = time.time() - start
print()
print('During Time: {:.3f} s'.format(during))

epoch: 200, Loss: 0.39108, Acc: 0.92000
epoch: 400, Loss: 0.32230, Acc: 0.91000
epoch: 600, Loss: 0.28942, Acc: 0.91000
epoch: 800, Loss: 0.27000, Acc: 0.91000
epoch: 1000, Loss: 0.25711, Acc: 0.90000

During Time: 0.590 s
w0 = w[0].data[0]
w1 = w[1].data[0]
b0 = b.data[0]

plot_x = np.arange(0.2, 1, 0.01)
_plot_x = torch.from_numpy(plot_x)

_plot_y = (-w0 * _plot_x - b0) / w1
plot_y = _plot_y.numpy()

plt.plot(plot_x, plot_y, 'g', label='decision boundary')
plt.plot(plot_x0, plot_y0, 'ro', label='x_0')
plt.plot(plot_x1, plot_y1, 'bo', label='x_1')
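With the trained parameters, new samples are classified the same way. A small usage sketch (my addition; the raw scores 75 and 60 are made-up inputs, normalized with the same x0_max and x1_max as the training data):

with torch.no_grad():  # no gradients needed at inference time
    new_x = torch.tensor([[75 / x0_max, 60 / x1_max]])
    prob = logistic_regression(new_x)
print(prob.item())               # probability of class 1
print(int(prob.item() >= 0.5))   # predicted label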