1.监督学习与非监督学习的区别
https://www.zhihu.com/question/27138263/answer/635004780
2.实现简单的 y = wx + b 模型
from torch.autograd import Variable
import torch
# Ten evenly spaced training inputs (1..10) and the matching targets (10..1).
# Both tensors keep requires_grad=True; the bare expression on the last line
# displays them when this is run as a notebook cell.
x_train = torch.linspace(1, 10, 10, requires_grad=True)
y_train = torch.linspace(10, 1, 10, requires_grad=True)
x_train, y_train
(tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], requires_grad=True),
tensor([10., 9., 8., 7., 6., 5., 4., 3., 2., 1.], requires_grad=True))
# Trainable parameters of y = w * x + b: random initial slope, zero intercept.
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
def linear_model(x):
    """Return the affine prediction ``w * x + b``.

    ``w`` and ``b`` are the module-level parameters; they are combined with
    ``x`` element-wise.
    """
    return w * x + b
def get_loss(y_, y):
    """Return the mean squared error between predictions ``y_`` and targets ``y``."""
    return torch.mean((y_ - y) ** 2)
# Plain gradient descent on the mean squared error for 1000 steps.
for i in range(1000):
    y_ = linear_model(x_train)
    loss = get_loss(y_, y_train)
    loss.backward()
    # Update w and b with autograd tracking disabled so the step itself is
    # not recorded in the graph (learning rate 1e-2).
    with torch.no_grad():
        w -= 1e-2 * w.grad
        b -= 1e-2 * b.grad
    # Reset the gradients of w and b; backward() accumulates into .grad, so
    # clearing right after the update removes the need for a first-step guard.
    w.grad.zero_()
    b.grad.zero_()
    if i % 200 == 0:
        print('epoch: {}, loss: {}'.format(i, loss.item()))
epoch: 0, loss: 43.67679977416992
epoch: 200, loss: 4.732393264770508
epoch: 400, loss: 0.8790384531021118
epoch: 600, loss: 0.16328111290931702
epoch: 800, loss: 0.030329588800668716
3.Logistic回归模型
Logistic 回归模型的输出是一个介于 0 和 1 之间的概率值(由 sigmoid 函数得到)。
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
data = [(34.62365962451697, 78.0246928153624, 0.0),
(30.28671076822607, 43.89499752400101, 0.0),
(35.84740876993872, 72.90219802708364, 0.0),
(60.18259938620976, 86.30855209546826, 1.0),
(79.0327360507101, 75.3443764369103, 1.0),
(45.08327747668339, 56.3163717815305, 0.0),
(61.10666453684766, 96.51142588489624, 1.0),
(75.02474556738889, 46.55401354116538, 1.0),
(76.09878670226257, 87.42056971926803, 1.0),
(84.43281996120035, 43.53339331072109, 1.0),
(95.86155507093572, 38.22527805795094, 0.0),
(75.01365838958247, 30.60326323428011, 0.0),
(82.30705337399482, 76.48196330235604, 1.0),
(69.36458875970939, 97.71869196188608, 1.0),
(39.53833914367223, 76.03681085115882, 0.0),
(53.9710521485623, 89.20735013750205, 1.0),
(69.07014406283025, 52.74046973016765, 1.0),
(67.94685547711617, 46.67857410673128, 0.0),
(70.66150955499435, 92.92713789364831, 1.0),
(76.97878372747498, 47.57596364975532, 1.0),
(67.37202754570876, 42.83843832029179, 0.0),
(89.6767757507208, 65.79936592745237, 1.0),
(50.534788289883, 48.85581152764205, 0.0),
(34.21206097786789, 44.20952859866288, 0.0),
(77.9240914545704, 68.9723599933059, 1.0),
(62.27101367004632, 69.95445795447587, 1.0),
(80.1901807509566, 44.82162893218353, 1.0),
(93.114388797442, 38.80067033713209, 0.0),
(61.83020602312595, 50.25610789244621, 0.0),
(38.78580379679423, 64.99568095539578, 0.0),
(61.379289447425, 72.80788731317097, 1.0),
(85.40451939411645, 57.05198397627122, 1.0),
(52.10797973193984, 63.12762376881715, 0.0),
(52.04540476831827, 69.43286012045222, 1.0),
(40.23689373545111, 71.16774802184875, 0.0),
(54.63510555424817, 52.21388588061123, 0.0),
(33.91550010906887, 98.86943574220611, 0.0),
(64.17698887494485, 80.90806058670817, 1.0),
(74.78925295941542, 41.57341522824434, 0.0),
(34.1836400264419, 75.2377203360134, 0.0),
(83.90239366249155, 56.30804621605327, 1.0),
(51.54772026906181, 46.85629026349976, 0.0),
(94.44336776917852, 65.56892160559052, 1.0),
(82.36875375713919, 40.61825515970618, 0.0),
(51.04775177128865, 45.82270145776001, 0.0),
(62.22267576120188, 52.06099194836679, 0.0),
(77.19303492601364, 70.45820000180959, 1.0),
(97.77159928000232, 86.7278223300282, 1.0),
(62.07306379667647, 96.76882412413983, 1.0),
(91.56497449807442, 88.696292545466, 1.0),
(79.94481794066932, 74.16311935043758, 1.0),
(99.2725269292572, 60.99903099844988, 1.0),
(90.54671411399852, 43.39060180650027, 1.0),
(34.52451385320009, 60.39634245837173, 0.0),
(50.2864961189907, 49.80453881323059, 0.0),
(49.58667721632031, 59.80895099453265, 0.0),
(97.64563396007767, 68.86157272420604, 1.0),
(32.57720016809309, 95.59854761387875, 0.0),
(74.24869136721598, 69.82457122657193, 1.0),
(71.7964620586338, 78.45356224515052, 1.0),
(75.3956114656803, 85.75993667331619, 1.0),
(35.28611281526193, 47.02051394723416, 0.0),
(56.25381749711624, 39.26147251058019, 0.0),
(30.05882244669796, 49.59297386723685, 0.0),
(44.66826172480893, 66.45008614558913, 0.0),
(66.56089447242954, 41.09209807936973, 0.0),
(40.45755098375164, 97.53518548909936, 1.0),
(49.07256321908844, 51.88321182073966, 0.0),
(80.27957401466998, 92.11606081344084, 1.0),
(66.74671856944039, 60.99139402740988, 1.0),
(32.72283304060323, 43.30717306430063, 0.0),
(64.0393204150601, 78.03168802018232, 1.0),
(72.34649422579923, 96.22759296761404, 1.0),
(60.45788573918959, 73.09499809758037, 1.0),
(58.84095621726802, 75.85844831279042, 1.0),
(99.82785779692128, 72.36925193383885, 1.0),
(47.26426910848174, 88.47586499559782, 1.0),
(50.45815980285988, 75.80985952982456, 1.0),
(60.45555629271532, 42.50840943572217, 0.0),
(82.22666157785568, 42.71987853716458, 0.0),
(88.9138964166533, 69.80378889835472, 1.0),
(94.83450672430196, 45.69430680250754, 1.0),
(67.31925746917527, 66.58935317747915, 1.0),
(57.23870631569862, 59.51428198012956, 1.0),
(80.36675600171273, 90.96014789746954, 1.0),
(68.46852178591112, 85.59430710452014, 1.0),
(42.0754545384731, 78.84478600148043, 0.0),
(75.47770200533905, 90.42453899753964, 1.0),
(78.63542434898018, 96.64742716885644, 1.0),
(52.34800398794107, 60.76950525602592, 0.0),
(94.09433112516793, 77.15910509073893, 1.0),
(90.44855097096364, 87.50879176484702, 1.0),
(55.48216114069585, 35.57070347228866, 0.0),
(74.49269241843041, 84.84513684930135, 1.0),
(89.84580670720979, 45.35828361091658, 1.0),
(83.48916274498238, 48.38028579728175, 1.0),
(42.2617008099817, 87.10385094025457, 1.0),
(99.31500880510394, 68.77540947206617, 1.0),
(55.34001756003703, 64.9319380069486, 1.0),
(74.77589300092767, 89.52981289513276, 1.0)]
# Scale each feature by its maximum value.
x0_max = max(row[0] for row in data)
x1_max = max(row[1] for row in data)
data = [(f0 / x0_max, f1 / x1_max, label) for f0, f1, label in data]

# Split the samples by class label (last tuple element) for plotting.
x0 = [row for row in data if row[-1] == 0.0]
x1 = [row for row in data if row[-1] == 1.0]
plot_x0 = [row[0] for row in x0]
plot_y0 = [row[1] for row in x0]
plot_x1 = [row[0] for row in x1]
plot_y1 = [row[1] for row in x1]

# Class 0 in red, class 1 in blue.
plt.plot(plot_x0, plot_y0, 'ro', label='x_0')
plt.plot(plot_x1, plot_y1, 'bo', label='x_1')
plt.legend(loc='best')
# Convert the samples to float32 tensors: the first two columns are the
# features, the last column is the label reshaped into a column vector.
np_data = np.array(data, dtype=np.float32)
x_data = torch.from_numpy(np_data[:, :2])
y_data = torch.unsqueeze(torch.from_numpy(np_data[:, -1]), 1)
# The larger the input, the closer the result is to 1; the smaller, the closer to 0.
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Evaluated as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp`` so that
    large negative inputs do not overflow inside ``exp``.
    """
    return np.exp(-np.logaddexp(0, -x))
# Wrap the data tensors and create the trainable weight (2 x 1, random) and
# bias (1 element, zero) of the logistic model.
x_data, y_data = Variable(x_data), Variable(y_data)
w = torch.randn(2, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
def logistic_regression(x):
    """Return ``sigmoid(x @ w + b)`` for a 2-D input ``x``.

    ``torch.mm`` multiplies ``x`` by the module-level weight ``w``; the
    module-level bias ``b`` is added before the sigmoid is applied.
    """
    return torch.sigmoid(torch.mm(x, w) + b)
# Classification result of the initial (untrained) parameters.
# The plotted line is where w0 * x + w1 * y + b0 == 0,
# i.e. y = (-w0 * x - b0) / w1.
w0, w1 = w.detach()[:, 0]
b0 = b.detach()[0]
plot_x = np.arange(0.2, 1, 0.01)
_plot_x = torch.from_numpy(plot_x)
_plot_y = (-w0 * _plot_x - b0) / w1
plot_y = _plot_y.numpy()
plt.plot(plot_x, plot_y, 'g', label='line1')
plt.plot(plot_x0, plot_y0, 'ro', label='x_0')
plt.plot(plot_x1, plot_y1, 'bo', label='x_1')
# Draw the legend; without this call the labels above are never shown.
plt.legend(loc='best')
# Compute the loss
def binary_loss(y_pred, y):
    """Return the mean binary cross-entropy of ``y_pred`` against ``y``.

    Both probability terms are clamped to at least ``1e-12`` before taking
    the log, so predictions of exactly 0 or 1 still give a finite loss.
    """
    log_likelihood = (
        y * y_pred.clamp(1e-12).log()
        + (1 - y) * (1 - y_pred).clamp(1e-12).log()
    ).mean()
    return -log_likelihood
y_pred = logistic_regression(x_data)
loss = binary_loss(y_pred, y_data)
print("first: ", loss)
# Back-propagate, then take one gradient-descent step (learning rate 0.1)
# with autograd tracking disabled so the update itself is not recorded.
loss.backward()
with torch.no_grad():
    w -= 0.1 * w.grad
    b -= 0.1 * b.grad
# Loss after the single update
y_pred = logistic_regression(x_data)
loss = binary_loss(y_pred, y_data)
print("second: ", loss)
first: tensor(0.7332, grad_fn=<NegBackward>)
second: tensor(0.7268, grad_fn=<NegBackward>)
# 使用 torch.optim 更新参数
from torch import nn
import torch.nn.functional as F
# Re-create the weight (2 x 1, random) and bias (1 element, zero) as
# nn.Parameter objects.
w = nn.Parameter(torch.randn((2, 1)), requires_grad=True)
b = nn.Parameter(torch.zeros((1,)), requires_grad=True)
def logistic_regression(x):
    """Return ``sigmoid(x @ w + b)`` for a 2-D input ``x``.

    Uses ``torch.sigmoid``; ``torch.nn.functional.sigmoid`` is deprecated
    and emits a ``UserWarning`` on every call.
    """
    return torch.sigmoid(torch.mm(x, w) + b)
# Stochastic gradient descent over the weight and bias, learning rate 1.0.
optimizer = torch.optim.SGD(params=[w, b], lr=1.0)
# Run 1000 updates
import time
start = time.time()
for e in range(1000):
    # Forward pass
    y_pred = logistic_regression(x_data)
    loss = binary_loss(y_pred, y_data)  # compute the loss
    # Backward pass
    optimizer.zero_grad()  # let the optimizer reset the gradients to 0
    loss.backward()
    optimizer.step()  # let the optimizer update the parameters
    if (e + 1) % 200 == 0:
        # Accuracy is only needed when it is reported. It is measured on the
        # predictions made before this step's update, thresholded at 0.5.
        mask = y_pred.ge(0.5).float()
        acc = (mask == y_data).sum().item() / y_data.shape[0]
        print('epoch: {}, Loss: {:.5f}, Acc: {:.5f}'.format(e+1, loss.item(), acc))
during = time.time() - start
print()
print('During Time: {:.3f} s'.format(during))
G:\Anaconda\envs\pytorch\lib\site-packages\torch\nn\functional.py:1386: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
epoch: 200, Loss: 0.39108, Acc: 0.92000
epoch: 400, Loss: 0.32230, Acc: 0.91000
epoch: 600, Loss: 0.28942, Acc: 0.91000
epoch: 800, Loss: 0.27000, Acc: 0.91000
epoch: 1000, Loss: 0.25711, Acc: 0.90000
During Time: 0.590 s
# Plot the line where w0 * x + w1 * y + b0 == 0 for the current parameters,
# i.e. y = (-w0 * x - b0) / w1, on top of the two classes.
w0, w1 = w.detach()[:, 0]
b0 = b.detach()[0]
plot_x = np.arange(0.2, 1, 0.01)
_plot_x = torch.from_numpy(plot_x)
_plot_y = (-w0 * _plot_x - b0) / w1
plot_y = _plot_y.numpy()
plt.plot(plot_x, plot_y, 'g', label='cutting line')
plt.plot(plot_x0, plot_y0, 'ro', label='x_0')
plt.plot(plot_x1, plot_y1, 'bo', label='x_1')
# Draw the legend; without this call the labels above are never shown.
plt.legend(loc='best')