Pytorch实现
上一个博客我们用传统方法实现了一个线性分类器,这里我们用pytorch实现
我们用一个单层全连接层加一个sigmoid激活函数实现
网络的计算图如下:
代码如下:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.utils.data as d
import torch.nn.functional as func
# 创建训练样本
N = 2000
# CLASS 1
x_c1 = np.random.randn(N, 2)
x_c1 = np.add(x_c1, [10, 10])
y_c1 = np.ones((N, 1), dtype=np.double)
# CLASS 2
x_c2 = np.random.randn(N, 2)
x_c2 = np.add(x_c2, [2, 5])
y_c2 = np.zeros((N, 1), dtype=np.double)
# 生成数据
data_x = np.concatenate((x_c1, x_c2), 0)
data_y = np.concatenate((y_c1, y_c2), 0)
tensor_x = torch.tensor(data_x, dtype=torch.float)
tensor_y = torch.tensor(data_y, dtype=torch.float)
# tensor_x=data, tensor_y=lab