Implementing logistic regression from scratch:
(1) Import the required packages
- import torch
- from IPython import display
- from matplotlib import pyplot as plt
- import numpy as np
- import random
(2) Generate the dataset
- # Generate the dataset: two Gaussian clusters in 2-D
- n_data = torch.ones(50, 2)
- x1 = torch.normal(2 * n_data, 1)      # class 0: 50 points centered at (2, 2)
- y1 = torch.zeros(50)
- x2 = torch.normal(-2 * n_data, 1)     # class 1: 50 points centered at (-2, -2)
- y2 = torch.ones(50)
- x = torch.cat((x1, x2), 0).type(torch.FloatTensor)   # features, shape (100, 2)
- y = torch.cat((y1, y2), 0).type(torch.FloatTensor)   # labels, shape (100,)
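Since matplotlib is already imported, a quick scatter plot confirms that the two clusters are roughly linearly separable. This is an optional sketch that uses only the x and y tensors created above:
- # Optional: visualize the two clusters, colored by label
- plt.scatter(x[:, 0].numpy(), x[:, 1].numpy(), c=y.numpy(), s=20)
- plt.xlabel('feature 1')
- plt.ylabel('feature 2')
- plt.show()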
(3) Define the data iterator
- # Define the mini-batch iterator
- def data_iter(batch_size, features, labels):
-     num_examples = len(features)
-     indices = list(range(num_examples))
-     random.shuffle(indices)    # shuffle so examples are visited in random order
-     for i in range(0, num_examples, batch_size):
-         j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)])   # the last batch may be smaller
-         yield features.index_select(0, j), labels.index_select(0, j)
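To see that the iterator behaves as expected, one can draw a single mini-batch and check its shape (batch size 10 here is just for illustration):
- # Example: fetch one mini-batch and inspect its shape
- for X_batch, y_batch in data_iter(10, x, y):
-     print(X_batch.shape, y_batch.shape)   # torch.Size([10, 2]) torch.Size([10])
-     break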
(4) Initialize the model parameters
- # Initialize the learnable parameters
- w = torch.tensor(np.random.normal(0, 0.01, (2, 1)), dtype=torch.float32)   # weights, shape (2, 1)
- b = torch.zeros(1, dtype=torch.float32)                                    # bias
- w.requires_grad_(requires_grad=True)
- b.requires_grad_(requires_grad=True)
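An equivalent initialization can also be written entirely in PyTorch, without the NumPy round-trip; this is only an alternative sketch, and the rest of the code works with either version:
- # Alternative: initialize directly with PyTorch factory functions
- w = (torch.randn(2, 1) * 0.01).requires_grad_(True)
- b = torch.zeros(1, requires_grad=True)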
(5) Define the model functions
- # Linear model (kept for comparison; not used in training below)
- def linreg(X, w, b):
-     return torch.mm(X, w) + b
- # Logistic (sigmoid) decision function: sigma(Xw + b)
- def logistic(X, w, b):
-     return 1 / (1 + torch.exp(-(torch.mm(X, w) + b)))
- # Binary cross-entropy loss (element-wise)
- def bce_loss(y_hat, y):
-     y = y.view(len(y_hat), -1)
-     return -1 * (y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat))
- # Squared loss (kept for comparison; not used in training below)
- def squared_loss(y_hat, y):
-     return (y_hat - y.view(y_hat.size())) ** 2 / 2
- # Mini-batch stochastic gradient descent
- def sgd(params, lr, batch_size):
-     for param in params:
-         param.data -= lr * param.grad / batch_size
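As a sanity check, the hand-written functions can be compared with PyTorch's built-in sigmoid and binary cross-entropy on a small made-up batch; the values should agree up to floating-point error:
- # Optional sanity check against torch.sigmoid and F.binary_cross_entropy
- import torch.nn.functional as F
- X_check = torch.randn(4, 2)
- y_check = torch.tensor([0., 1., 0., 1.])
- p_manual = logistic(X_check, w, b)
- p_builtin = torch.sigmoid(torch.mm(X_check, w) + b)
- print(torch.allclose(p_manual, p_builtin))   # True
- manual = bce_loss(p_manual, y_check).mean()
- builtin = F.binary_cross_entropy(p_manual, y_check.view(-1, 1))
- print(torch.allclose(manual, builtin))       # True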
(6) Train the model and compute the loss for each epoch
- # Train and report the loss after every epoch
- lr = 0.03
- num_epochs = 20
- batch_size = 10
- net = logistic
- # loss = torch.nn.BCELoss()
- loss = bce_loss
- for epoch in range(num_epochs):
-     for X, Y in data_iter(batch_size, x, y):
-         l = loss(net(X, w, b), Y).sum()   # total loss of this mini-batch
-         l.backward()                      # compute gradients of w and b
-         sgd([w, b], lr, batch_size)       # update the parameters
-         w.grad.data.zero_()               # reset the gradients before the next batch
-         b.grad.data.zero_()
-     train_l = loss(net(x, w, b), y)
-     print('epoch %d, loss %f' % (epoch + 1, train_l.mean().item()))   # mean loss over the whole training set after this epoch
-     acc_sum = (net(x, w, b).ge(0.5).float().view(-1, 1) == y.view(-1, 1)).float().sum().item()
-     print('accuracy %f' % (acc_sum / y.shape[0]))                     # training accuracy after this epoch
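For comparison only, the same classifier could be trained with PyTorch's built-in modules (nn.Linear, nn.BCEWithLogitsLoss, optim.SGD); this concise version is a sketch and is not needed by the from-scratch code above:
- # Optional: concise version with built-in modules (for comparison)
- import torch.nn as nn
- import torch.optim as optim
- model = nn.Linear(2, 1)
- criterion = nn.BCEWithLogitsLoss()                  # fuses sigmoid and BCE, numerically stable
- optimizer = optim.SGD(model.parameters(), lr=0.03)
- for epoch in range(20):
-     for X_b, y_b in data_iter(10, x, y):
-         optimizer.zero_grad()
-         l = criterion(model(X_b), y_b.view(-1, 1))
-         l.backward()
-         optimizer.step()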
(7) Print the learned parameters
- print('\n', w)
- print('\n', b)
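With the learned parameters, new points can be classified by thresholding the predicted probability at 0.5; the two test points below are made up for illustration:
- # Example: classify two new (hypothetical) points with the learned w and b
- with torch.no_grad():
-     x_new = torch.tensor([[2.0, 2.0], [-2.0, -2.0]])
-     probs = logistic(x_new, w, b)
-     print(probs, probs.ge(0.5).float())   # the point near (2, 2) should map to label 0, the one near (-2, -2) to label 1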