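Preface
A feedforward neural network (FNN), or feedforward network for short, is a type of artificial neural network. It has a one-directional, multi-layer structure in which each layer contains a number of neurons. Each neuron receives signals from the neurons of the previous layer and sends its output to the next layer. Layer 0 is called the input layer and the last layer is called the output layer. The layers in between are called hidden layers. A network can have a single hidden layer or several.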
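I. Neurons
The basic unit of a neural network is a neuron with a nonlinear activation function; its structure is shown in the figure below. A neuron is a simplified model of the structure and behavior of a biological neuron: it receives a set of input signals and produces an output.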
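1. Net input
Suppose a neuron receives the input $\boldsymbol{x} \in \mathbb{R}^D$ and has the weight vector $\boldsymbol{w} \in \mathbb{R}^D$. The neuron's input signal, i.e., its net input $z$, is computed as
$$z = \boldsymbol{w}^\top \boldsymbol{x} + b,$$
where $b$ is the bias.
To make prediction more efficient, we usually group $N$ samples into a batch and predict them together:
$$\boldsymbol{z} = \boldsymbol{X}\boldsymbol{w} + b,$$
where $\boldsymbol{X} \in \mathbb{R}^{N \times D}$ is the feature matrix of the $N$ samples and $\boldsymbol{z} \in \mathbb{R}^N$ is the column vector of the $N$ predictions.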
Computing the net input for a batch of inputs with PyTorch:
import torch
# 2 samples with 5 features each
X = torch.rand([2, 5])
# Parameters
w = torch.rand([5, 1])
b = torch.rand([1, 1])
# Use 'torch.matmul' for matrix multiplication
z = torch.matmul(X, w) + b
print("input X:", X)
print("weight w:", w, "\nbias b:", b)
print("output z:", z)
Output:
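As a quick sanity check of the batched formula $\boldsymbol{z} = \boldsymbol{X}\boldsymbol{w} + b$, the sketch below compares it with computing $\boldsymbol{w}^\top\boldsymbol{x} + b$ one sample at a time. The sample count and sizes are arbitrary assumptions. Both results should agree up to floating-point rounding.

```python
import torch

# Sketch: the batched net input z = Xw + b should equal w^T x + b computed per sample.
torch.manual_seed(0)
X = torch.rand([4, 5])   # 4 samples with 5 features each (illustrative sizes)
w = torch.rand([5, 1])
b = torch.rand([1, 1])

# Batched computation: one matrix product for all samples, shape [4, 1]
z_batch = torch.matmul(X, w) + b

# Per-sample computation: each row gives a scalar w^T x + b; stack them back into [4, 1]
z_loop = torch.stack([torch.dot(w.squeeze(1), x) + b.squeeze() for x in X]).unsqueeze(1)

print(z_batch.shape)                     # torch.Size([4, 1])
print(torch.allclose(z_batch, z_loop))   # True
```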
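Thought question:
What is the difference and relationship between a weighted sum and an affine transformation?
In my view, a weighted sum $\boldsymbol{w}^\top\boldsymbol{x}$ is essentially a linear transformation. It maps lines to lines, preserves ratios along a line, and maps the origin to the origin. An affine transformation is a linear transformation followed by a translation (here, adding the bias $b$). It still maps lines to lines and preserves ratios along a line, but it no longer necessarily maps the origin to the origin. The net input $z = \boldsymbol{w}^\top\boldsymbol{x} + b$ is therefore an affine transformation of $\boldsymbol{x}$, and it reduces to a pure weighted sum when $b = 0$.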
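A small numerical illustration of the answer above, assuming a random 3-dimensional $\boldsymbol{w}$ and a bias of 0.5 (both arbitrary choices). The weighted sum satisfies superposition and keeps the origin fixed. Adding the bias sends the origin to $b$ and breaks superposition.

```python
import torch

# Sketch contrasting a weighted sum (linear) with weighted sum + bias (affine).
torch.manual_seed(0)
w = torch.rand(3)
b = torch.tensor(0.5)   # illustrative bias

def weighted_sum(x):
    return torch.dot(w, x)        # linear: w^T x

def affine(x):
    return torch.dot(w, x) + b    # affine: w^T x + b

x, y = torch.rand(3), torch.rand(3)
alpha, beta = 2.0, -3.0
zero = torch.zeros(3)

# Linear: f(alpha*x + beta*y) == alpha*f(x) + beta*f(y), and f(0) == 0
print(torch.allclose(weighted_sum(alpha * x + beta * y),
                     alpha * weighted_sum(x) + beta * weighted_sum(y), atol=1e-6))  # True
print(weighted_sum(zero).item())   # 0.0

# Affine: the origin maps to b, and superposition fails because b != 0
print(affine(zero).item())         # 0.5
print(torch.allclose(affine(alpha * x + beta * y),
                     alpha * affine(x) + beta * affine(y), atol=1e-6))              # False
```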
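2. Activation functions
The net input $z$ is then passed through a nonlinear function $f(\cdot)$ to obtain the neuron's activation $a$:
$$a = f(z).$$
Activation functions are usually nonlinear, which strengthens the network's representational and learning capacity. Commonly used activation functions are Sigmoid-type functions and ReLU-type functions.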
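2.1 Sigmoid-type functions
Sigmoid-type functions are a family of S-shaped curves that saturate at both ends. The most common ones are the Logistic function and the Tanh function:
Logistic function:
$$\sigma(z) = \frac{1}{1 + \exp(-z)}.$$
Tanh function:
$$\tanh(z) = \frac{\exp(z) - \exp(-z)}{\exp(z) + \exp(-z)}.$$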
Implementing and visualizing the Logistic and Tanh functions
Implementation:
import torch
import matplotlib.pyplot as plt
# Logistic function
def logistic(z):
    return 1.0 / (1.0 + torch.exp(-z))
# Tanh function
def tanh(z):
    return (torch.exp(z) - torch.exp(-z)) / (torch.exp(z) + torch.exp(-z))
# Generate 10000 inputs in [-10, 10] for plotting the curves
z = torch.linspace(-10, 10, 10000)
plt.figure()
plt.plot(z.tolist(), logistic(z).tolist(), color='#e4007f', label="Logistic Function")
plt.plot(z.tolist(), tanh(z).tolist(), color='#f19ec2', linestyle ='--', label="Tanh Function")
ax = plt.gca()  # get the current axes (4 spines by default)
# Hide the top and right spines by setting their color to 'none'
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
# Move the left and bottom spines so they pass through the origin
ax.spines['left'].set_position(('data',0))
ax.spines['bottom'].set_position(('data',0))
plt.legend(loc='lower right', fontsize='large')
plt.savefig('fw-logistic-tanh.pdf')
plt.show()
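To double-check the hand-written functions, the following sketch compares them with PyTorch's built-in torch.sigmoid and torch.tanh. It also verifies the identity $\tanh(z) = 2\sigma(2z) - 1$, which shows that Tanh is a rescaled, zero-centered Logistic function with range $(-1, 1)$ instead of $(0, 1)$. The tolerance atol=1e-6 is an assumption chosen to absorb float32 rounding.

```python
import torch

# Self-contained copies of the functions defined above
def logistic(z):
    return 1.0 / (1.0 + torch.exp(-z))

def tanh(z):
    return (torch.exp(z) - torch.exp(-z)) / (torch.exp(z) + torch.exp(-z))

z = torch.linspace(-10, 10, 1001)

# The hand-written versions should match PyTorch's built-ins
print(torch.allclose(logistic(z), torch.sigmoid(z), atol=1e-6))   # True
print(torch.allclose(tanh(z), torch.tanh(z), atol=1e-6))          # True

# Tanh as a rescaled, shifted Logistic: tanh(z) = 2 * sigma(2z) - 1
print(torch.allclose(tanh(z), 2 * logistic(2 * z) - 1, atol=1e-6))  # True
```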
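2.2 ReLU-type functions
Common ReLU-type functions are ReLU and Leaky ReLU, defined as:
$$\mathrm{ReLU}(z) = \max(0, z),$$
$$\mathrm{LeakyReLU}(z) = \max(0, z) + \lambda \min(0, z),$$
where $\lambda$ is a hyperparameter (the slope for negative inputs).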
Implementing and visualizing the ReLU and Leaky ReLU functions
Implementation:
def relu(z):
    return torch.maximum(z, torch.tensor(0.))
def leaky_relu(z, negative_slope=0.1):
    # Convert the boolean masks to float32 with '.float()' so they can multiply z;
    # torch.can_cast only checks whether a dtype cast is allowed and returns a bool, so it cannot do this conversion
    a1 = (z > 0).float() * z
    a2 = (z <= 0).float() * (negative_slope * z)
    return a1 + a2
# Generate inputs in [-10, 10] for plotting the ReLU and Leaky ReLU curves
z = torch.linspace(-10, 10, 10000)
plt.figure()
plt.plot(z.tolist(), relu(z).tolist(), color="#e4007f", label="ReLU Function")
plt.plot(z.tolist(), leaky_relu(z).tolist(), color="#f19ec2", linestyle="--", label="LeakyReLU Function")
ax = plt.gca()
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')
ax.spines['left'].set_position(('data',0))
ax.spines['bottom'].set_position(('data',0))
plt.legend(loc='upper left', fontsize='large')
plt.savefig('fw-relu-leakyrelu.pdf')
plt.show()
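This sketch compares the two functions with PyTorch's torch.relu and torch.nn.functional.leaky_relu, assuming a negative slope of 0.1 as in the default above. It also shows why torch.can_cast cannot serve as a mask. The function only reports whether a dtype conversion is allowed and returns a Python bool. Multiplying by that bool scales the whole input instead of selecting elements.

```python
import torch
import torch.nn.functional as F

# Self-contained copies of the definitions above (with explicit .float() masks)
def relu(z):
    return torch.maximum(z, torch.tensor(0.))

def leaky_relu(z, negative_slope=0.1):
    return (z > 0).float() * z + (z <= 0).float() * (negative_slope * z)

z = torch.linspace(-10, 10, 101)

# Compare against PyTorch's built-in implementations
print(torch.allclose(relu(z), torch.relu(z)))                               # True
print(torch.allclose(leaky_relu(z), F.leaky_relu(z, negative_slope=0.1)))  # True

# torch.can_cast answers "is this cast allowed?" with a plain bool; it converts nothing
print(torch.can_cast(torch.bool, torch.float32))                            # True
```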
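II. Binary classification with a feedforward neural network
The structure of a feedforward neural network is shown in the figure below. Each layer takes the activations of the previous layer, repeats the computation above to obtain its own activations, and passes them to the next layer. The network contains no feedback: signals propagate in one direction, layer by layer, from the input layer to the output layer, producing the network's final output $\boldsymbol{a}^{(L)}$.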
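1. Building the dataset
We use the binary classification dataset built in the previous experiment, Moon1000. It is split into 640 training, 160 validation and 200 test samples. The data are sampled from two noisy crescent-moon-shaped distributions, and each sample has 2 features.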
from nndl2.dataset import make_moons
n_samples = 1000
X, y = make_moons(n_samples=n_samples, shuffle=True, noise=0.15)
num_train = 640
num_dev = 160
num_test = 200
X_train, y_train = X[:num_train], y[:num_train]
X_dev, y_dev = X[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
X_test, y_test = X[num_train + num_dev:], y[num_train + num_dev:]
y_train = y_train.reshape([-1,1])
y_dev = y_dev.reshape([-1,1])
y_test = y_test.reshape([-1,1])
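2. Building the model
To build the feedforward network efficiently, we first define an operator for each kind of layer and then compose these operators into the full network.
Suppose the input to layer $l$ is the activation $\boldsymbol{a}^{(l-1)}$ of layer $l-1$. An affine transformation gives the net input $\boldsymbol{z}^{(l)}$ of this layer. Passing it through the activation function gives the layer's activation $\boldsymbol{a}^{(l)}$.
In practice, to process data efficiently, we usually group $N$ samples into a batch. Suppose the input to layer $l$ is $\boldsymbol{A}^{(l-1)} \in \mathbb{R}^{N \times M_{l-1}}$, where each row is one sample and $M_l$ is the number of neurons in layer $l$. The computation of layer $l$ is then
$$\boldsymbol{Z}^{(l)} = \boldsymbol{A}^{(l-1)} \boldsymbol{W}^{(l)} + \boldsymbol{b}^{(l)} \in \mathbb{R}^{N \times M_l},$$
$$\boldsymbol{A}^{(l)} = f_l(\boldsymbol{Z}^{(l)}) \in \mathbb{R}^{N \times M_l},$$
Here $\boldsymbol{Z}^{(l)}$ holds the net inputs of layer $l$ for the $N$ samples and $\boldsymbol{A}^{(l)}$ holds their activations. $\boldsymbol{W}^{(l)} \in \mathbb{R}^{M_{l-1} \times M_l}$ is the weight matrix of layer $l$ and $\boldsymbol{b}^{(l)} \in \mathbb{R}^{1 \times M_l}$ is its bias.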
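The batched layer formulas can be sketched as a loop over layers. The helper mlp_forward below is a hypothetical illustration, not part of the operator library built in this section. The layer sizes [2, 5, 3, 1] and the use of the Logistic function in every layer are assumptions.

```python
import torch

# Sketch of the batched computation Z^(l) = A^(l-1) W^(l) + b^(l), A^(l) = f_l(Z^(l)).
# mlp_forward is a hypothetical helper; the Logistic activation in every layer is an assumption.
def mlp_forward(X, weights, biases, act=torch.sigmoid):
    A = X                              # A^(0): the input batch, shape [N, M_0]
    for W, b in zip(weights, biases):
        Z = torch.matmul(A, W) + b     # net input, shape [N, M_l]
        A = act(Z)                     # activation, shape [N, M_l]
    return A

torch.manual_seed(0)
sizes = [2, 5, 3, 1]                   # M_0 (input), two hidden layers, output
weights = [torch.randn(m_in, m_out) for m_in, m_out in zip(sizes[:-1], sizes[1:])]
biases = [torch.zeros(1, m_out) for m_out in sizes[1:]]

X = torch.rand(4, 2)                   # N = 4 samples, each row is one sample
out = mlp_forward(X, weights, biases)
print(out.shape)                       # torch.Size([4, 1])
```

Because each row is processed independently, feeding one row alone gives the same result as the corresponding row of the batch.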
2.1 Linear operator
from nndl.op import Op
# Linear layer operator
class Linear(Op):
    def __init__(self, input_size, output_size, name, weight_init=torch.normal, bias_init=torch.zeros):
        self.params = {}
        # Initialize the weights from a standard normal distribution: torch.normal(mean, std, size)
        self.params['W'] = weight_init(0, 1, [input_size, output_size])
        # Initialize the biases to zero
        self.params['b'] = bias_init([1, output_size])
        self.inputs = None
        self.name = name
    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(self.inputs, self.params['W']) + self.params['b']
        return outputs
2.2 Logistic operator
class Logistic(Op):
    def __init__(self):
        self.inputs = None
        self.outputs = None
    def forward(self, inputs):
        outputs = 1.0 / (1.0 + torch.exp(-inputs))
        self.outputs = outputs
        return outputs
2.3 Stacking layers in sequence
class Model_MLP_L2(Op):
    def __init__(self, input_size, hidden_size, output_size):
        self.fc1 = Linear(input_size, hidden_size, name="fc1")
        self.act_fn1 = Logistic()
        self.fc2 = Linear(hidden_size, output_size, name="fc2")
        self.act_fn2 = Logistic()
    def __call__(self, X):
        return self.forward(X)
    def forward(self, X):
        z1 = self.fc1(X)
        a1 = self.act_fn1(z1)
        z2 = self.fc2(a1)
        a2 = self.act_fn2(z2)
        return a2
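The input layer has dimension 5, the hidden layer dimension 10, and the output layer dimension 1.
We randomly generate one sample of length 5, feed it into the two-layer network, and inspect the output.
# Instantiate the model
model = Model_MLP_L2(input_size=5, hidden_size=10, output_size=1)
# Randomly generate 1 sample of length 5
X = torch.rand([1, 5])
result = model(X)
print("result: ", result)
Output:
result: tensor([[0.8152]])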
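3. Loss function
We use the binary cross-entropy loss. For $N$ samples with labels $\boldsymbol{y} \in \{0, 1\}^N$ and predicted probabilities $\hat{\boldsymbol{y}} \in (0, 1)^N$, it is
$$\mathcal{R} = -\frac{1}{N}\left(\boldsymbol{y}^\top \log \hat{\boldsymbol{y}} + (\boldsymbol{1} - \boldsymbol{y})^\top \log(\boldsymbol{1} - \hat{\boldsymbol{y}})\right),$$
where $\log$ is applied element-wise.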
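As a check of the formula, the sketch below uses a small hand-made batch (the numbers are illustrative). It compares the matrix form with the mean of the per-sample losses and with torch.nn.functional.binary_cross_entropy, which averages over samples by default.

```python
import torch
import torch.nn.functional as F

# Illustrative batch: predicted probabilities and 0/1 labels, both of shape [N, 1]
y_hat = torch.tensor([[0.9], [0.2], [0.6], [0.4]])
y = torch.tensor([[1.0], [0.0], [0.0], [1.0]])
N = y.shape[0]

# Matrix form: -(1/N) * (y^T log(y_hat) + (1 - y)^T log(1 - y_hat)), shape [1, 1]
loss_matrix = -1.0 / N * (torch.matmul(y.t(), torch.log(y_hat))
                          + torch.matmul((1 - y).t(), torch.log(1 - y_hat)))

# Per-sample losses, then averaged
per_sample = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat))
loss_mean = per_sample.mean()

# PyTorch's built-in binary cross-entropy (reduction='mean' by default)
loss_builtin = F.binary_cross_entropy(y_hat, y)

print(torch.allclose(loss_matrix.squeeze(), loss_mean))   # True
print(torch.allclose(loss_mean, loss_builtin))            # True
print(round(loss_mean.item(), 4))                         # 0.5403
```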
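4. Model optimization
The parameters of a neural network are mainly optimized with gradient descent, so we need the gradient of the final loss with respect to every parameter.
Gradient computation here differs from the linear classification model in the previous chapter. A linear model is usually simple enough that its gradient can be computed directly. A neural network is typically deeper and is a composite function, so its gradients must be computed by backpropagation using the chain rule.
4.1 Backpropagation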
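The two-layer network used in this experiment applies Linear, Logistic, Linear, Logistic in sequence. Applying the chain rule layer by layer gives the gradients below. This is a sketch of the derivation that the operators in the following subsections implement step by step. The notation is as follows:
- $\mathcal{R}$ is the binary cross-entropy loss averaged over $N$ samples and $\boldsymbol{X}$ is the input batch;
- $\boldsymbol{A}^{(1)}$ is the hidden activation and $\hat{\boldsymbol{y}} = \boldsymbol{A}^{(2)}$ the predicted probabilities;
- $\boldsymbol{\delta}^{(l)}$ is the gradient with respect to the net input $\boldsymbol{Z}^{(l)}$;
- $\odot$ is element-wise multiplication and $\boldsymbol{1}$ an all-ones column vector.

```latex
\begin{aligned}
\frac{\partial \mathcal{R}}{\partial \hat{\boldsymbol{y}}} &= -\frac{1}{N}\left(\frac{\boldsymbol{y}}{\hat{\boldsymbol{y}}} - \frac{\boldsymbol{1}-\boldsymbol{y}}{\boldsymbol{1}-\hat{\boldsymbol{y}}}\right), \\
\boldsymbol{\delta}^{(2)} = \frac{\partial \mathcal{R}}{\partial \boldsymbol{Z}^{(2)}} &= \frac{\partial \mathcal{R}}{\partial \hat{\boldsymbol{y}}} \odot \hat{\boldsymbol{y}} \odot (\boldsymbol{1}-\hat{\boldsymbol{y}}) = \frac{1}{N}\left(\hat{\boldsymbol{y}} - \boldsymbol{y}\right), \\
\frac{\partial \mathcal{R}}{\partial \boldsymbol{W}^{(2)}} &= \boldsymbol{A}^{(1)\top} \boldsymbol{\delta}^{(2)}, \qquad
\frac{\partial \mathcal{R}}{\partial \boldsymbol{b}^{(2)}} = \boldsymbol{1}^\top \boldsymbol{\delta}^{(2)}, \\
\frac{\partial \mathcal{R}}{\partial \boldsymbol{A}^{(1)}} &= \boldsymbol{\delta}^{(2)} \boldsymbol{W}^{(2)\top}, \\
\boldsymbol{\delta}^{(1)} = \frac{\partial \mathcal{R}}{\partial \boldsymbol{Z}^{(1)}} &= \frac{\partial \mathcal{R}}{\partial \boldsymbol{A}^{(1)}} \odot \boldsymbol{A}^{(1)} \odot (\boldsymbol{1}-\boldsymbol{A}^{(1)}), \\
\frac{\partial \mathcal{R}}{\partial \boldsymbol{W}^{(1)}} &= \boldsymbol{X}^\top \boldsymbol{\delta}^{(1)}, \qquad
\frac{\partial \mathcal{R}}{\partial \boldsymbol{b}^{(1)}} = \boldsymbol{1}^\top \boldsymbol{\delta}^{(1)}.
\end{aligned}
```

The simplification in the second line follows from $-\frac{1}{N}\big(\boldsymbol{y} \odot (\boldsymbol{1}-\hat{\boldsymbol{y}}) - (\boldsymbol{1}-\boldsymbol{y}) \odot \hat{\boldsymbol{y}}\big) = \frac{1}{N}(\hat{\boldsymbol{y}} - \boldsymbol{y})$.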
4.2 Loss function
The code is as follows:
class BinaryCrossEntropyLoss(Op):
    def __init__(self, model):
        self.predicts = None
        self.labels = None
        self.num = None
        self.model = model
    def __call__(self, predicts, labels):
        return self.forward(predicts, labels)
    def forward(self, predicts, labels):
        self.predicts = predicts
        self.labels = labels
        self.num = self.predicts.shape[0]
        # -(1/N) * (y^T log(y_hat) + (1 - y)^T log(1 - y_hat)), shape [1, 1]
        loss = -1. / self.num * (torch.matmul(self.labels.t(), torch.log(self.predicts))
                                 + torch.matmul((1 - self.labels.t()), torch.log(1 - self.predicts)))
        loss = torch.squeeze(loss, dim=1)
        return loss
    def backward(self):
        # Derivative of the loss with respect to the model's predictions
        loss_grad_predicts = -1.0 * (self.labels / self.predicts -
                                     (1 - self.labels) / (1 - self.predicts)) / self.num
        # Backpropagate the gradient through the model
        self.model.backward(loss_grad_predicts)
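Combining BinaryCrossEntropyLoss.backward with the Logistic derivative gives a well-known simplification for the output layer: $\partial\mathcal{R}/\partial\boldsymbol{z} = (\hat{\boldsymbol{y}} - \boldsymbol{y})/N$. Here $\mathcal{R}$ is the averaged binary cross-entropy and $\boldsymbol{z}$ the output layer's net input. The sketch below checks this on random data; the sizes and seed are assumptions. Autograd serves only as a reference, since the operators in this section compute gradients by hand.

```python
import torch

# Sketch: for a Logistic output + binary cross-entropy, dR/dz = (y_hat - y) / N.
torch.manual_seed(0)
N = 6
z = torch.randn(N, 1, requires_grad=True)      # net input of the output layer
y = torch.randint(0, 2, (N, 1)).float()        # random 0/1 labels

y_hat = torch.sigmoid(z)
loss = -1.0 / N * (torch.matmul(y.t(), torch.log(y_hat))
                   + torch.matmul((1 - y).t(), torch.log(1 - y_hat)))
loss.sum().backward()                           # loss has shape [1, 1]; autograd fills z.grad

with torch.no_grad():
    # Same formula as BinaryCrossEntropyLoss.backward
    grad_y_hat = -1.0 * (y / y_hat - (1 - y) / (1 - y_hat)) / N
    # Chain rule through the Logistic derivative sigma * (1 - sigma)
    grad_z_manual = grad_y_hat * y_hat * (1 - y_hat)

print(torch.allclose(z.grad, grad_z_manual, atol=1e-6))             # True
print(torch.allclose(z.grad, (y_hat.detach() - y) / N, atol=1e-6))  # True
```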
4.3 Logistic operator
class Logistic(Op):
    def __init__(self):
        self.inputs = None
        self.outputs = None
        self.params = None
    def forward(self, inputs):
        outputs = 1.0 / (1.0 + torch.exp(-inputs))
        self.outputs = outputs
        return outputs
    def backward(self, grads):
        # Derivative of the Logistic function with respect to its input: sigma * (1 - sigma)
        outputs_grad_inputs = torch.multiply(self.outputs, (1.0 - self.outputs))
        return torch.multiply(grads,outputs_grad_inputs)
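A quick check of the derivative used in Logistic.backward, $\sigma'(z) = \sigma(z)(1-\sigma(z))$, against autograd. The input range is an arbitrary choice.

```python
import torch

# Sketch: compare sigma(z) * (1 - sigma(z)) with autograd's derivative of sigmoid
z = torch.linspace(-6, 6, 25, requires_grad=True)
s = torch.sigmoid(z)
s.sum().backward()               # d(sum_i sigma(z_i)) / d z_i = sigma'(z_i)

with torch.no_grad():
    manual = s * (1 - s)         # the formula used in Logistic.backward

print(torch.allclose(z.grad, manual, atol=1e-6))   # True
```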
4.4 Linear layer
class Linear(Op):
    def __init__(self, input_size, output_size, name, weight_init=torch.normal, bias_init=torch.zeros):
        self.params = {}
        self.params['W'] = weight_init(0, 1, [input_size, output_size])
        self.params['b'] = bias_init([1, output_size])
        self.inputs = None
        self.grads = {}
        self.name = name
    def forward(self, inputs):
        self.inputs = inputs
        outputs = torch.matmul(self.inputs, self.params['W']) + self.params['b']
        return outputs
    def backward(self, grads):
        # Gradient with respect to the weights: X^T times the upstream gradient
        self.grads['W'] = torch.matmul(self.inputs.T, grads)
        # Gradient with respect to the bias: upstream gradient summed over the batch
        self.grads['b'] = torch.sum(grads, dim=0)
        # Gradient with respect to the layer input, passed on to the previous layer
        return torch.matmul(grads, self.params['W'].T)
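The three formulas in Linear.backward can be checked against autograd. In this sketch a random upstream gradient G stands in for the gradient arriving from the next layer. The sizes are arbitrary assumptions.

```python
import torch

# Sketch: check the gradients computed in Linear.backward against autograd.
torch.manual_seed(0)
N, M_in, M_out = 4, 3, 2
X = torch.randn(N, M_in, requires_grad=True)
W = torch.randn(M_in, M_out, requires_grad=True)
b = torch.zeros(1, M_out, requires_grad=True)
G = torch.randn(N, M_out)                          # stand-in upstream gradient dL/dZ

Z = torch.matmul(X, W) + b
Z.backward(G)                                      # autograd: gradients weighted by G

with torch.no_grad():
    grad_W = torch.matmul(X.T, G)                  # as self.grads['W']
    grad_b = torch.sum(G, dim=0, keepdim=True)     # as self.grads['b'], kept 2-D to match b
    grad_X = torch.matmul(G, W.T)                  # value returned to the previous layer

print(torch.allclose(W.grad, grad_W, atol=1e-6))   # True
print(torch.allclose(b.grad, grad_b, atol=1e-6))   # True
print(torch.allclose(X.grad, grad_X, atol=1e-6))   # True
```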
4.5 The full network
Implement the forward and backward computation of the complete two-layer neural network.
class Model_MLP_L2(Op):
    def __init__(self, input_size, hidden_size, output_size):
        # Linear layer
        self.fc1 = Linear(input_size, hidden_size, name="fc1")
        # Logistic activation layer
        self.act_fn1 = Logistic()
        self.fc2 = Linear(hidden_size, output_size, name="fc2")
        self.act_fn2 = Logistic()
        self.layers = [self.fc1, self.act_fn1, self.fc2, self.act_fn2]
    def __call__(self, X):
        return self.forward(X)
    # Forward computation
    def forward(self, X):
        z1 = self.fc1(X)
        a1 = self.act_fn1(z1)
        z2 = self.fc2(a1)
        a2 = self.act_fn2(z2)
        return a2
    # Backward computation: pass the loss gradient from the output layer back toward the input
    def backward(self, loss_grad_a2):
        loss_grad_z2 = self.act_fn2.backward(loss_grad_a2)
        loss_grad_a1 = self.fc2.backward(loss_grad_z2)
        loss_grad_z1 = self.act_fn1.backward(loss_grad_a1)
        loss_grad_inputs = self.fc1.backward(loss_grad_z1)
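To check the whole hand-written backward pass, the sketch below repeats the same computations in functional form for a small random batch. It then compares each parameter gradient with autograd. The steps mirror BinaryCrossEntropyLoss.backward and Model_MLP_L2.backward. The sizes, seed and tolerance are assumptions.

```python
import torch

# Sketch: functional version of the two-layer network's forward and manual backward pass,
# compared with autograd. Sizes, seed and tolerance are illustrative assumptions.
torch.manual_seed(0)
N, D, H = 8, 2, 5
X = torch.randn(N, D)
y = torch.randint(0, 2, (N, 1)).float()
W1 = torch.randn(D, H, requires_grad=True)
b1 = torch.zeros(1, H, requires_grad=True)
W2 = torch.randn(H, 1, requires_grad=True)
b2 = torch.zeros(1, 1, requires_grad=True)

# Forward pass (same computation as Model_MLP_L2.forward)
a1 = torch.sigmoid(torch.matmul(X, W1) + b1)
a2 = torch.sigmoid(torch.matmul(a1, W2) + b2)
loss = -1.0 / N * (torch.matmul(y.t(), torch.log(a2))
                   + torch.matmul((1 - y).t(), torch.log(1 - a2)))
loss.sum().backward()                                   # reference gradients from autograd

# Manual backward pass
with torch.no_grad():
    g_a2 = -1.0 * (y / a2 - (1 - y) / (1 - a2)) / N     # BinaryCrossEntropyLoss.backward
    g_z2 = g_a2 * a2 * (1 - a2)                          # act_fn2.backward
    g_W2 = torch.matmul(a1.T, g_z2)                      # fc2 weight gradient
    g_b2 = g_z2.sum(dim=0, keepdim=True)                 # fc2 bias gradient
    g_a1 = torch.matmul(g_z2, W2.T)                      # fc2.backward return value
    g_z1 = g_a1 * a1 * (1 - a1)                          # act_fn1.backward
    g_W1 = torch.matmul(X.T, g_z1)                       # fc1 weight gradient
    g_b1 = g_z1.sum(dim=0, keepdim=True)                 # fc1 bias gradient

for name, auto, manual in [("W1", W1.grad, g_W1), ("b1", b1.grad, g_b1),
                           ("W2", W2.grad, g_W2), ("b2", b2.grad, g_b2)]:
    print(name, torch.allclose(auto, manual, atol=1e-5))  # each line ends with True
```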
4.6 Optimizer
After computing the gradients of the network parameters, we implement the gradient descent parameter update in an optimizer.
from abc import abstractmethod
class Optimizer(object):
    def __init__(self, init_lr, model):  # initialize the optimizer
        # Learning rate used in the parameter update
        self.init_lr = init_lr
        # The model whose parameters this optimizer updates
        self.model = model
    @abstractmethod
    def step(self):
        """
        Define how the parameters are updated at each iteration
        """
        pass
class BatchGD(Optimizer):
    def __init__(self, init_lr, model):
        super(BatchGD, self).__init__(init_lr=init_lr, model=model)
    def step(self):
        # Parameter update
        for layer in self.model.layers:  # iterate over all layers
            if isinstance(layer.params, dict):
                for key in layer.params.keys():
                    layer.params[key] = layer.params[key] - self.init_lr * layer.grads[key]
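The update rule in BatchGD.step, $p \leftarrow p - \eta \, \partial f / \partial p$, can be tried on a toy problem. ToyLayer, ToyModel and batch_gd_step are hypothetical stand-ins. They only mimic the params, grads and layers attributes that BatchGD relies on. The objective $f(w) = (w - 3)^2$ and the learning rate 0.1 are illustrative.

```python
import torch

# Hypothetical stand-ins exposing the attributes BatchGD.step uses
class ToyLayer:
    def __init__(self):
        self.params = {'w': torch.tensor([0.0])}
        self.grads = {}

class ToyModel:
    def __init__(self):
        self.layers = [ToyLayer()]

def batch_gd_step(model, lr):
    # Same loop as BatchGD.step: update every parameter of every layer with a params dict
    for layer in model.layers:
        if isinstance(layer.params, dict):
            for key in layer.params.keys():
                layer.params[key] = layer.params[key] - lr * layer.grads[key]

model = ToyModel()
for _ in range(100):
    layer = model.layers[0]
    layer.grads['w'] = 2 * (layer.params['w'] - 3.0)   # gradient of f(w) = (w - 3)^2
    batch_gd_step(model, lr=0.1)

print(model.layers[0].params['w'])   # very close to 3
```

With this learning rate the error shrinks by a factor of 0.8 per step, so after 100 steps it is negligible.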
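5. Improving the Runner class: RunnerV2_1
The Runner class from the previous experiment was designed for relatively simple models. In this experiment the model is composed of several operators, so a few features need to be added:
- Gradient computation for custom operators: during training, call self.loss_fn.backward() to backpropagate gradients starting from the loss function;
- Per-layer model saving and loading: the parameters of each layer are saved and loaded separately.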
import os
class RunnerV2_1(object):
    def __init__(self, model, optimizer, metric, loss_fn, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.metric = metric
        # Record how the evaluation metric changes during training
        self.train_scores = []
        self.dev_scores = []
        # Record how the loss changes during training
        self.train_loss = []
        self.dev_loss = []
    def train(self, train_set, dev_set, **kwargs):
        # Number of training epochs (defaults to 0 if not given)
        num_epochs = kwargs.get("num_epochs", 0)
        # Logging interval in epochs (defaults to 100 if not given)
        log_epochs = kwargs.get("log_epochs", 100)
        # Directory for saving the model
        save_dir = kwargs.get("save_dir", None)
        # Best metric seen so far
        best_score = 0
        # Train for num_epochs epochs
        for epoch in range(num_epochs):
            X, y = train_set
            # Model predictions
            logits = self.model(X)
            # Cross-entropy loss
            trn_loss = self.loss_fn(logits, y)  # returns a tensor
            self.train_loss.append(trn_loss.item())
            # Evaluation metric
            trn_score = self.metric(logits, y).item()
            self.train_scores.append(trn_score)
            # Backpropagate from the loss function to compute all gradients
            self.loss_fn.backward()
            # Parameter update
            self.optimizer.step()
            dev_score, dev_loss = self.evaluate(dev_set)
            # Save the model whenever the validation metric improves
            if dev_score > best_score:
                print(f"[Evaluate] best accuracy performance has been updated: {best_score:.5f} --> {dev_score:.5f}")
                best_score = dev_score
                if save_dir:
                    self.save_model(save_dir)
            if log_epochs and epoch % log_epochs == 0:
                print(f"[Train] epoch: {epoch}/{num_epochs}, loss: {trn_loss.item()}")
    def evaluate(self, data_set):
        X, y = data_set
        # Model outputs
        logits = self.model(X)
        # Loss
        loss = self.loss_fn(logits, y).item()
        self.dev_loss.append(loss)
        # Evaluation metric
        score = self.metric(logits, y).item()
        self.dev_scores.append(score)
        return score, loss
    def predict(self, X):
        return self.model(X)
    def save_model(self, save_dir):
        # Save each layer's parameters to a separate file named after the layer
        os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing directories
        for layer in self.model.layers:  # iterate over all layers
            if isinstance(layer.params, dict):
                torch.save(layer.params, os.path.join(save_dir, layer.name+".pt"))
    def load_model(self, model_dir):
        # Map each layer name to its saved parameter file
        model_file_names = os.listdir(model_dir)
        name_file_dict = {}
        for file_name in model_file_names:
            name = file_name.replace(".pt","")
            name_file_dict[name] = os.path.join(model_dir, file_name)
        # Load each layer's parameters
        for layer in self.model.layers:  # iterate over all layers
            if isinstance(layer.params, dict):
                name = layer.name
                file_path = name_file_dict[name]
                layer.params = torch.load(file_path)
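The per-layer save/load scheme can be tried end to end on a throwaway model. DummyLayer, DummyModel and the temporary directory are illustrative assumptions. The file naming (<layer name>.pt) and the name-to-file lookup follow save_model and load_model above.

```python
import os
import tempfile
import torch

# Hypothetical layers/model exposing the `name`, `params` and `layers` attributes the Runner uses
class DummyLayer:
    def __init__(self, name, in_size, out_size):
        self.name = name
        self.params = {'W': torch.randn(in_size, out_size), 'b': torch.zeros(1, out_size)}

class DummyModel:
    def __init__(self):
        self.layers = [DummyLayer("fc1", 2, 5), DummyLayer("fc2", 5, 1)]

torch.manual_seed(0)
src, dst = DummyModel(), DummyModel()     # dst starts with different random weights

with tempfile.TemporaryDirectory() as save_dir:
    for layer in src.layers:              # save: one file per layer, named after the layer
        torch.save(layer.params, os.path.join(save_dir, layer.name + ".pt"))
    name_file_dict = {f.replace(".pt", ""): os.path.join(save_dir, f)
                      for f in os.listdir(save_dir)}
    for layer in dst.layers:              # load: look up each layer's file by its name
        layer.params = torch.load(name_file_dict[layer.name])

print(all(torch.equal(s.params['W'], d.params['W'])
          for s, d in zip(src.layers, dst.layers)))   # True
```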
6. Model training
Train the model on the training and validation sets for 1000 epochs.
# Evaluation metric (from Section 3.1.5)
def accuracy(preds, labels):
    # preds.shape[1] == 1 means binary classification; preds.shape[1] > 1 means multi-class classification
    if preds.shape[1] == 1:
        # Binary case: predict class 1 when the probability is at least 0.5, otherwise class 0
        # '.float()' converts the boolean mask to float32 (torch.can_cast only checks whether a cast is allowed)
        preds = (preds >= 0.5).float()
    else:
        # Multi-class case: use 'torch.argmax' to take the index of the largest element as the class
        preds = torch.argmax(preds, dim=1)
    return torch.mean((preds == labels).float())
torch.manual_seed(123)
epoch_num = 1000
model_saved_dir = "D:\\model"
# Input layer dimension: 2
input_size = 2
# Hidden layer dimension: 5
hidden_size = 5
# Output layer dimension: 1
output_size = 1
# Define the network
model = Model_MLP_L2(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Loss function
loss_fn = BinaryCrossEntropyLoss(model)
# Optimizer
learning_rate = 2.0
optimizer = BatchGD(learning_rate, model)
# Evaluation metric
metric = accuracy
# Instantiate RunnerV2_1 with the training configuration
runner = RunnerV2_1(model, optimizer, metric, loss_fn)
runner.train([X_train, y_train], [X_dev, y_dev], num_epochs=epoch_num, log_epochs=50, save_dir=model_saved_dir)
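Below is a small hand-made example of the thresholding in accuracy; the numbers are illustrative. It also shows what happens if torch.can_cast is used in place of a real conversion. That call returns the Python bool True, so comparing it with the labels only checks which labels equal 1. The resulting "accuracy" is then just the fraction of positive labels, roughly 0.5 on a balanced dataset.

```python
import torch

# Self-contained copy of the thresholding version of accuracy
def accuracy(preds, labels):
    if preds.shape[1] == 1:
        preds = (preds >= 0.5).float()        # threshold probabilities at 0.5
    else:
        preds = torch.argmax(preds, dim=1)    # multi-class: index of the largest score
    return torch.mean((preds == labels).float())

preds = torch.tensor([[0.9], [0.2], [0.6], [0.4]])
labels = torch.tensor([[1.0], [0.0], [0.0], [0.0]])
print(accuracy(preds, labels).item())         # 0.75

# Pitfall: torch.can_cast returns a plain bool (True here), not a converted tensor
bad_preds = torch.can_cast(torch.bool, torch.float32)
print(torch.mean((bad_preds == labels).float()).item())   # 0.25, the fraction of labels equal to 1
```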
Visualize how the training and validation losses change.
print(runner.train_loss)
# Plot the training and validation losses
plt.figure()
plt.plot(range(epoch_num), runner.train_loss, color="#e4007f", label="Train loss")
plt.plot(range(epoch_num), runner.dev_loss, color="#f19ec2", linestyle='--', label="Dev loss")
plt.xlabel("epoch", fontsize='large')
plt.ylabel("loss", fontsize='large')
plt.legend(fontsize='x-large')
plt.savefig('fw-loss2.pdf')
plt.show()
Output:
[0.7783548831939697, 0.7278994917869568, 0.7125570774078369, 0.7000342607498169, 0.6891533136367798, 0.679099440574646, 0.6692914366722107, 0.6593152284622192, 0.6488897204399109, 0.6378504633903503, 0.6261442303657532, 0.6138232350349426, 0.6010352373123169, 0.5880035758018494, 0.5749973058700562, 0.5622959136962891, 0.5501543283462524, 0.5387770533561707, 0.5283035635948181, 0.5188069939613342, 0.5103022456169128, 0.5027592778205872, 0.4961172044277191, 0.490297794342041, 0.4852154850959778, 0.4807843863964081, 0.4769230782985687, 0.47355756163597107, 0.470621258020401, 0.4680560231208801, 0.4658114016056061, 0.46384397149086, 0.462116539478302, 0.4605972468852997, 0.45925894379615784, 0.4580782353878021, 0.4570351541042328, 0.4561125338077545, 0.4552955627441406, 0.45457133650779724, 0.45392876863479614, 0.45335817337036133, 0.4528510272502899, 0.4523998200893402, 0.4519980847835541, 0.45164012908935547, 0.45132094621658325, 0.45103588700294495, 0.4507812559604645, 0.45055341720581055, 0.45034924149513245, 0.45016607642173767, 0.4500015676021576, 0.4498535096645355, 0.4497201442718506, 0.44959965348243713, 0.44949063658714294, 0.44939175248146057, 0.44930192828178406, 0.4492201805114746, 0.4491454064846039, 0.4490770399570465, 0.44901424646377563, 0.44895657896995544, 0.44890329241752625, 0.4488539695739746, 0.4488081932067871, 0.44876575469970703, 0.448726087808609, 0.4486890435218811, 0.44865426421165466, 0.44862157106399536, 0.44859081506729126, 0.4485616683959961, 0.448534220457077, 0.44850802421569824, 0.44848328828811646, 0.4484596252441406, 0.44843703508377075, 0.44841551780700684, 0.44839486479759216, 0.44837504625320435, 0.4483560621738434, 0.4483378529548645, 0.448320209980011, 0.44830322265625, 0.44828692078590393, 0.44827109575271606, 0.44825586676597595, 0.44824114441871643, 0.4482267498970032, 0.4482129216194153, 0.4481995701789856, 0.448186457157135, 0.4481738209724426, 0.4481615126132965, 0.44814959168434143, 0.4481379985809326, 
0.44812676310539246, 0.448115736246109, 0.4481050670146942, 0.44809457659721375, 0.44808444380760193, 0.4480746388435364, 0.4480649530887604, 0.44805556535720825, 0.4480464458465576, 0.44803744554519653, 0.4480287730693817, 0.4480202794075012, 0.4480118751525879, 0.44800376892089844, 0.4479958713054657, 0.4479881227016449, 0.4479805529117584, 0.44797325134277344, 0.44796591997146606, 0.44795891642570496, 0.447952002286911, 0.4479452669620514, 0.4479386508464813, 0.4479321539402008, 0.447925865650177, 0.44791969656944275, 0.44791364669799805, 0.4479076564311981, 0.44790181517601013, 0.44789621233940125, 0.44789057970046997, 0.447885125875473, 0.4478797912597656, 0.44787460565567017, 0.4478694498538971, 0.4478643834590912, 0.4478594958782196, 0.4478547275066376, 0.44784989953041077, 0.4478452205657959, 0.4478406012058258, 0.4478362202644348, 0.44783177971839905, 0.44782742857933044, 0.4478231966495514, 0.4478190541267395, 0.44781503081321716, 0.44781094789505005, 0.44780704379081726, 0.44780316948890686, 0.447799414396286, 0.4477955996990204, 0.4477919042110443, 0.4477883279323578, 0.44778481125831604, 0.4477812945842743, 0.4477778375148773, 0.4477744698524475, 0.4477711617946625, 0.44776782393455505, 0.44776469469070435, 0.4477614462375641, 0.44775840640068054, 0.4477553069591522, 0.44775229692459106, 0.4477493464946747, 0.4477464258670807, 0.4477435052394867, 0.4477406442165375, 0.44773778319358826, 0.44773513078689575, 0.44773244857788086, 0.4477296769618988, 0.44772711396217346, 0.44772443175315857, 0.4477218687534332, 0.44771939516067505, 0.44771692156791687, 0.4477144181728363, 0.4477120041847229, 0.4477095603942871, 0.44770726561546326, 0.447704941034317, 0.44770264625549316, 0.4477003216743469, 0.44769811630249023, 0.44769594073295593, 0.44769373536109924, 0.44769158959388733, 0.4476894438266754, 0.4476873576641083, 0.4476853013038635, 0.4476832449436188, 0.4476812481880188, 0.4476792514324188, 0.44767728447914124, 0.44767528772354126, 0.44767341017723083, 
0.44767147302627563, 0.4476695656776428, 0.4476676881313324, 0.44766589999198914, 0.4476639926433563, 0.44766226410865784, 0.4476604461669922, 0.4476586878299713, 0.44765692949295044, 0.4476552903652191, 0.447653591632843, 0.44765186309814453, 0.4476501941680908, 0.4476485848426819, 0.4476469159126282, 0.4476453363895416, 0.4476436674594879, 0.44764214754104614, 0.4476405680179596, 0.44763898849487305, 0.4476374685764313, 0.4476359486579895, 0.4476345181465149, 0.44763293862342834, 0.44763150811195374, 0.44762998819351196, 0.4476286470890045, 0.4476272165775299, 0.4476257860660553, 0.4476242959499359, 0.44762298464775085, 0.44762155413627625, 0.4476202130317688, 0.44761887192726135, 0.4476175010204315, 0.44761618971824646, 0.4476148784160614, 0.4476136267185211, 0.4476121962070465, 0.4476109445095062, 0.44760972261428833, 0.44760844111442566, 0.44760724902153015, 0.4476059377193451, 0.4476047456264496, 0.4476035237312317, 0.4476023316383362, 0.4476011395454407, 0.4475998878479004, 0.44759875535964966, 0.44759756326675415, 0.44759640097618103, 0.4475953280925751, 0.44759416580200195, 0.4475930333137512, 0.4475918710231781, 0.44759073853492737, 0.447589635848999, 0.4475885331630707, 0.4475875496864319, 0.4475863575935364, 0.4475853145122528, 0.4475843012332916, 0.4475831985473633, 0.4475822150707245, 0.44758111238479614, 0.44758009910583496, 0.4475790560245514, 0.4475780427455902, 0.4475770890712738, 0.44757604598999023, 0.44757506251335144, 0.4475741386413574, 0.44757309556007385, 0.44757214188575745, 0.44757118821144104, 0.44757017493247986, 0.4475692808628082, 0.4475683271884918, 0.4475673735141754, 0.447566419839859, 0.4475655257701874, 0.44756457209587097, 0.44756370782852173, 0.4475627541542053, 0.4475618302822113, 0.44756102561950684, 0.44756007194519043, 0.4475591778755188, 0.44755837321281433, 0.4475574493408203, 0.44755658507347107, 0.4475557506084442, 0.44755488634109497, 0.4475540220737457, 0.4475531578063965, 0.447552353143692, 0.4475514888763428, 
0.4475506842136383, 0.4475499093532562, 0.44754910469055176, 0.4475482404232025, 0.44754743576049805, 0.4475465714931488, 0.4475458562374115, 0.44754505157470703, 0.4475441873073578, 0.44754353165626526, 0.4475427269935608, 0.4475419223308563, 0.44754114747047424, 0.44754037261009216, 0.44753965735435486, 0.4475388526916504, 0.4475381374359131, 0.4475373923778534, 0.44753655791282654, 0.447535902261734, 0.44753512740135193, 0.44753438234329224, 0.44753366708755493, 0.4475329518318176, 0.4475322663784027, 0.447531521320343, 0.44753074645996094, 0.4475300908088684, 0.4475294053554535, 0.44752874970436096, 0.4475279748439789, 0.44752731919288635, 0.4475266635417938, 0.4475259482860565, 0.4475253224372864, 0.4475246071815491, 0.44752398133277893, 0.4475232660770416, 0.4475225508213043, 0.4475218951702118, 0.44752126932144165, 0.4475206434726715, 0.447519987821579, 0.44751930236816406, 0.44751864671707153, 0.4475180208683014, 0.44751739501953125, 0.4475167691707611, 0.44751617312431335, 0.44751548767089844, 0.4475148618221283, 0.44751426577568054, 0.4475136697292328, 0.44751301407814026, 0.4475123882293701, 0.4475117623806, 0.4475111961364746, 0.44751057028770447, 0.4475099742412567, 0.4475093483924866, 0.44750872254371643, 0.44750815629959106, 0.4475075900554657, 0.4475069046020508, 0.4475064277648926, 0.4475058615207672, 0.4475052058696747, 0.44750460982322693, 0.44750410318374634, 0.44750356674194336, 0.44750291109085083, 0.44750234484672546, 0.4475018084049225, 0.44750118255615234, 0.44750067591667175, 0.447500079870224, 0.4474995732307434, 0.4474990963935852, 0.44749853014945984, 0.4474979043006897, 0.4474973678588867, 0.44749686121940613, 0.4474962651729584, 0.44749578833580017, 0.4474951922893524, 0.4474947154521942, 0.44749417901039124, 0.44749361276626587, 0.44749316573143005, 0.4474925696849823, 0.4474920332431793, 0.44749146699905396, 0.44749101996421814, 0.44749051332473755, 0.4474899470806122, 0.447489470243454, 0.4474889934062958, 0.4474884569644928, 
0.4474879205226898, 0.44748741388320923, 0.447486937046051, 0.44748640060424805, 0.44748592376708984, 0.44748541712760925, 0.44748494029045105, 0.44748449325561523, 0.44748398661613464, 0.44748345017433167, 0.44748303294181824, 0.44748249650001526, 0.44748201966285706, 0.44748154282569885, 0.44748106598854065, 0.44748058915138245, 0.44748011231422424, 0.44747963547706604, 0.44747909903526306, 0.44747868180274963, 0.44747820496559143, 0.447477787733078, 0.4474773108959198, 0.4474768340587616, 0.4474762976169586, 0.4474759101867676, 0.4474754333496094, 0.44747495651245117, 0.44747456908226013, 0.44747409224510193, 0.4474736154079437, 0.4474732577800751, 0.44747278094291687, 0.4474722445011139, 0.44747182726860046, 0.44747138023376465, 0.4474709630012512, 0.447470486164093, 0.4474700391292572, 0.447469562292099, 0.44746923446655273, 0.44746872782707214, 0.4474683403968811, 0.4474678635597229, 0.4474674165248871, 0.44746705889701843, 0.44746658205986023, 0.4474661350250244, 0.447465717792511, 0.44746533036231995, 0.4474649131298065, 0.4474644660949707, 0.4474639892578125, 0.44746360182762146, 0.44746318459510803, 0.4474627673625946, 0.44746237993240356, 0.4474619925022125, 0.4474615156650543, 0.4474611282348633, 0.44746074080467224, 0.4474603235721588, 0.4474599063396454, 0.44745951890945435, 0.44745904207229614, 0.4474586546421051, 0.44745832681655884, 0.447457879781723, 0.4474574029445648, 0.44745704531669617, 0.4474566876888275, 0.4474562704563141, 0.44745588302612305, 0.4474554657936096, 0.44745513796806335, 0.44745466113090515, 0.4474542737007141, 0.44745388627052307, 0.4474535584449768, 0.447453111410141, 0.44745275378227234, 0.4474523961544037, 0.44745197892189026, 0.4474515914916992, 0.44745126366615295, 0.44745078682899475, 0.4474504590034485, 0.44745007157325745, 0.4474496841430664, 0.44744929671287537, 0.4474489390850067, 0.4474485516548157, 0.44744816422462463, 0.4474477767944336, 0.44744738936424255, 0.4474470615386963, 0.44744673371315, 
0.4474463164806366, 0.44744595885276794, 0.4474456012248993, 0.44744521379470825, 0.4474448263645172, 0.44744449853897095, 0.4474441707134247, 0.44744378328323364, 0.4474433958530426, 0.44744300842285156, 0.4474426209926605, 0.44744229316711426, 0.4474419057369232, 0.44744157791137695, 0.4474411904811859, 0.44744086265563965, 0.4474405348300934, 0.44744014739990234, 0.4474397599697113, 0.44743943214416504, 0.4474391043186188, 0.4474387764930725, 0.44743838906288147, 0.4474380612373352, 0.44743767380714417, 0.4474373757839203, 0.44743701815605164, 0.4474366307258606, 0.4474363327026367, 0.4474359452724457, 0.447435587644577, 0.4474352300167084, 0.4474349021911621, 0.4474346339702606, 0.4474342465400696, 0.4474339187145233, 0.44743356108665466, 0.4474332332611084, 0.44743284583091736, 0.4474325180053711, 0.44743219017982483, 0.44743186235427856, 0.4474315643310547, 0.44743117690086365, 0.4474308490753174, 0.4474305212497711, 0.44743022322654724, 0.44742995500564575, 0.4474295675754547, 0.44742918014526367, 0.4474289119243622, 0.44742852449417114, 0.4474281966686249, 0.447427898645401, 0.44742757081985474, 0.44742727279663086, 0.4474268853664398, 0.44742661714553833, 0.44742628931999207, 0.4474259316921234, 0.44742560386657715, 0.4474252760410309, 0.4474249482154846, 0.44742465019226074, 0.4474243223667145, 0.4474240243434906, 0.44742366671562195, 0.44742336869239807, 0.4474230706691742, 0.44742271304130554, 0.44742241501808167, 0.4474220871925354, 0.44742175936698914, 0.44742146134376526, 0.4474211633205414, 0.44742080569267273, 0.44742050766944885, 0.447420209646225, 0.4474198818206787, 0.44741955399513245, 0.44741925597190857, 0.4474188983440399, 0.44741860032081604, 0.44741830229759216, 0.4474180340766907, 0.4474177062511444, 0.4474174678325653, 0.44741708040237427, 0.4474167823791504, 0.44741639494895935, 0.44741615653038025, 0.447415828704834, 0.4474155604839325, 0.4474152624607086, 0.44741496443748474, 0.4474146068096161, 0.4474143087863922, 
0.44741401076316833, 0.44741374254226685, 0.44741344451904297, 0.44741305708885193, 0.4474128782749176, 0.44741249084472656, 0.4474122226238251, 0.4474119246006012, 0.4474116265773773, 0.44741129875183105, 0.4474109709262848, 0.4474107325077057, 0.4474104940891266, 0.44741010665893555, 0.44740986824035645, 0.4474095404148102, 0.4474092125892639, 0.4474089741706848, 0.4474085867404938, 0.4474083483219147, 0.4474080204963684, 0.44740772247314453, 0.44740748405456543, 0.44740721583366394, 0.4474068582057953, 0.4474065899848938, 0.4474062919616699, 0.44740602374076843, 0.44740572571754456, 0.4474054276943207, 0.4474050998687744, 0.4474048614501953, 0.4474045932292938, 0.44740429520606995, 0.4474039673805237, 0.4474037289619446, 0.4474034309387207, 0.4474031627178192, 0.44740286469459534, 0.4474025368690491, 0.44740229845046997, 0.4474020004272461, 0.4474017322063446, 0.4474014341831207, 0.44740113615989685, 0.44740086793899536, 0.4474005401134491, 0.44740030169487, 0.4474000036716461, 0.44739970564842224, 0.44739943742752075, 0.4473991394042969, 0.4473988711833954, 0.4473985731601715, 0.44739827513694763, 0.44739800691604614, 0.44739776849746704, 0.44739753007888794, 0.44739723205566406, 0.4473969638347626, 0.4473966658115387, 0.4473963677883148, 0.44739609956741333, 0.44739580154418945, 0.44739553332328796, 0.44739529490470886, 0.447394996881485, 0.4473946988582611, 0.44739437103271484, 0.44739413261413574, 0.44739389419555664, 0.44739362597465515, 0.4473933279514313, 0.4473930299282074, 0.4473927915096283, 0.4473925530910492, 0.44739219546318054, 0.44739198684692383, 0.44739171862602234, 0.44739142060279846, 0.4473911225795746, 0.4473908841609955, 0.447390615940094, 0.4473903775215149, 0.447390079498291, 0.4473898112773895, 0.44738951325416565, 0.44738927483558655, 0.4473889470100403, 0.4473886489868164, 0.4473883807659149, 0.4473881721496582, 0.4473879337310791, 0.4473876655101776, 0.44738736748695374, 0.44738706946372986, 0.44738683104515076, 0.44738656282424927, 
0.4473862648010254, 0.4473860263824463, 0.44738587737083435, 0.4473855197429657, 0.4473852813243866, 0.44738492369651794, 0.4473847448825836, 0.44738444685935974, 0.44738417863845825, 0.44738394021987915, 0.4473836123943329, 0.44738340377807617, 0.4473831355571747, 0.4473828375339508, 0.4473825991153717, 0.4473823606967926, 0.4473820626735687, 0.4473818242549896, 0.44738149642944336, 0.44738131761550903, 0.44738101959228516, 0.44738075137138367, 0.44738054275512695, 0.4473802149295807, 0.4473799765110016, 0.4473796784877777, 0.4473794996738434, 0.4473792016506195, 0.447378933429718, 0.4473786950111389, 0.4473784565925598, 0.4473782181739807, 0.44737792015075684, 0.44737768173217773, 0.44737741351127625, 0.44737711548805237, 0.44737693667411804, 0.44737663865089417, 0.44737645983695984, 0.4473761022090912, 0.4473758339881897, 0.447375625371933, 0.4473753869533539, 0.4473751187324524, 0.4473748207092285, 0.4473745822906494, 0.4473743438720703, 0.4473740756511688, 0.4473738372325897, 0.4473735988140106, 0.4473733603954315, 0.4473731219768524, 0.44737282395362854, 0.44737252593040466, 0.44737234711647034, 0.44737204909324646, 0.44737181067466736, 0.44737154245376587, 0.44737130403518677, 0.4473710060119629, 0.4473707675933838, 0.4473705291748047, 0.4473702609539032, 0.44737011194229126, 0.4473698139190674, 0.4473695456981659, 0.4473693072795868, 0.4473690092563629, 0.4473687708377838, 0.4473685920238495, 0.44736823439598083, 0.4473680555820465, 0.4473678171634674, 0.44736751914024353, 0.44736728072166443, 0.4473671019077301, 0.44736671447753906, 0.4473665654659271, 0.44736629724502563, 0.44736605882644653, 0.44736582040786743, 0.44736558198928833, 0.44736534357070923, 0.44736504554748535, 0.44736480712890625, 0.44736453890800476, 0.44736430048942566, 0.44736406207084656, 0.44736385345458984, 0.44736358523368835, 0.44736334681510925, 0.44736310839653015, 0.4473628103733063, 0.4473625719547272, 0.44736239314079285, 0.44736215472221375, 0.44736185669898987, 
0.447361558675766, 0.44736137986183167, 0.4473610818386078, 0.44736090302467346, 0.4473606050014496, 0.4473603665828705, 0.4473601281642914, 0.44735994935035706, 0.4473596513271332, 0.4473594129085541, 0.44735923409461975, 0.4473589062690735, 0.4473586976528168, 0.44735845923423767, 0.4473581910133362, 0.4473579525947571, 0.447357714176178, 0.4473574757575989, 0.447357177734375, 0.4473569989204407, 0.4473567605018616, 0.44735652208328247, 0.44735631346702576, 0.44735604524612427, 0.44735583662986755, 0.44735556840896606, 0.44735535979270935, 0.4473550021648407, 0.44735485315322876, 0.4473545551300049, 0.4473543167114258, 0.4473540782928467, 0.4473538398742676, 0.44735369086265564, 0.44735342264175415, 0.44735318422317505, 0.44735288619995117, 0.44735270738601685, 0.44735240936279297, 0.44735223054885864, 0.44735199213027954, 0.44735169410705566, 0.44735145568847656, 0.44735121726989746, 0.44735097885131836, 0.44735080003738403, 0.44735056161880493, 0.44735026359558105, 0.44735002517700195, 0.4473498463630676, 0.44734954833984375, 0.44734930992126465, 0.44734907150268555, 0.4473489224910736, 0.44734859466552734, 0.447348415851593, 0.44734811782836914, 0.4473479390144348, 0.4473477005958557, 0.44734740257263184, 0.4473472237586975, 0.4473470151424408, 0.44734668731689453, 0.4473465085029602, 0.4473462700843811, 0.447346031665802, 0.4473458230495453, 0.4473455548286438, 0.4473453462123871, 0.447345107793808, 0.4473448693752289, 0.4473446309566498, 0.4473443925380707, 0.44734421372413635, 0.4473439157009125, 0.44734373688697815, 0.4473434388637543, 0.44734326004981995, 0.44734296202659607, 0.44734272360801697, 0.44734248518943787, 0.44734230637550354, 0.4473421275615692, 0.44734182953834534, 0.44734153151512146, 0.44734135270118713, 0.44734111428260803, 0.4473409354686737, 0.4473406970500946, 0.4473404586315155, 0.4473402202129364, 0.4473399221897125, 0.4473397433757782, 0.4473395049571991, 0.44733926653862, 0.4473390579223633, 0.4473387897014618, 0.4473385810852051, 
0.447338342666626, 0.4473381042480469, 0.4473378658294678, 0.44733768701553345, 0.44733744859695435, 0.44733723998069763, 0.44733697175979614, 0.44733676314353943, 0.44733649492263794, 0.4473362863063812, 0.4473361074924469, 0.447335809469223, 0.4473356306552887, 0.4473353326320648, 0.4473350942134857, 0.4473348557949066, 0.4473346769809723, 0.4473343789577484, 0.4473342001438141, 0.447333961725235, 0.4473337233066559, 0.44733354449272156, 0.44733330607414246, 0.44733306765556335, 0.44733282923698425, 0.44733262062072754, 0.44733238220214844, 0.44733211398124695, 0.44733190536499023, 0.44733166694641113, 0.44733142852783203, 0.4473312497138977, 0.4473310112953186, 0.4473307728767395, 0.4473305642604828, 0.4473302960395813, 0.4473300874233246, 0.4473298490047455, 0.4473296105861664, 0.44732943177223206, 0.44732919335365295, 0.44732895493507385, 0.4473287761211395, 0.44732847809791565, 0.4473282992839813, 0.44732800126075745, 0.4473278224468231, 0.44732752442359924, 0.4473273456096649, 0.4473271369934082, 0.4473269581794739, 0.44732666015625, 0.4473264813423157, 0.4473261833190918, 0.44732600450515747, 0.44732576608657837, 0.44732552766799927, 0.44732531905174255, 0.44732505083084106, 0.4473249018192291, 0.44732466340065, 0.4473244845867157, 0.4473241865634918, 0.4473240077495575, 0.4473237991333008, 0.4473234713077545, 0.4473232328891754, 0.4473230838775635, 0.4473228454589844, 0.4473226070404053, 0.44732236862182617, 0.44732213020324707, 0.44732195138931274, 0.44732171297073364, 0.44732150435447693, 0.4473212659358978, 0.4473210275173187, 0.4473208487033844, 0.4473206102848053, 0.4473203718662262, 0.4473200738430023, 0.447319895029068, 0.44731971621513367, 0.44731950759887695, 0.44731926918029785, 0.44731903076171875, 0.4473188519477844, 0.44731855392456055, 0.4473183751106262, 0.4473181664943695, 0.447317898273468, 0.4473176896572113, 0.447317510843277, 0.4473172128200531, 0.4473170340061188, 0.4473167955875397, 0.44731664657592773, 0.44731637835502625, 
0.44731613993644714, 0.44731590151786804, 0.44731569290161133, 0.44731542468070984, 0.4473152160644531, 0.4473150372505188, 0.4473147988319397, 0.4473145604133606, 0.4473143517971039, 0.4473141133785248, 0.44731393456459045, 0.4473136067390442, 0.4473133981227875, 0.44731321930885315, 0.4473130404949188, 0.4473128020763397, 0.4473125636577606, 0.4473123550415039, 0.4473121166229248, 0.4473118782043457, 0.44731172919273376, 0.44731149077415466, 0.4473112225532532]
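7. Performance evaluation
Evaluate the best model saved during training on the test set and inspect its metrics.
# Load the trained model
runner.load_model(model_saved_dir)
# Evaluate the model on the test set
score, loss = runner.evaluate([X_test, y_test])
print("[Test] score/loss: {:.4f}/{:.4f}".format(score, loss))
Output: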
[Test] score/loss: 0.5150/0.7322
Visualizing the evaluation results
Before reducing the noise:
After reducing the noise:
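Thought question:
Compare 3.1 (binary classification with Logistic regression) and 4.2 (binary classification with a feedforward neural network).
Both use the cross-entropy loss, but the neural network is considerably more complex than Logistic regression. Logistic regression is a linear model suited to linearly separable problems, whereas a neural network can handle nonlinear problems. Most real-world problems are not linearly separable, so neural networks are more widely applicable. On the other hand, the extra layers and activation functions make a neural network slower to train than Logistic regression. A longer training time does not guarantee better classification.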
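A classic illustration of this point is XOR. Its two classes cannot be separated by a single line, so Logistic regression cannot classify all four points correctly, while a network with one hidden layer can. The weights below are hand-picked for illustration rather than learned. Hidden unit 1 behaves like OR, hidden unit 2 like NAND, and the output unit combines them like AND.

```python
import torch

# XOR inputs and labels
X = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([[0.], [1.], [1.], [0.]])

# Hand-picked (not learned) weights: column 1 ~ OR, column 2 ~ NAND
W1 = torch.tensor([[20., -20.],
                   [20., -20.]])
b1 = torch.tensor([[-10., 30.]])
# Output unit ~ AND of the two hidden units
W2 = torch.tensor([[20.], [20.]])
b2 = torch.tensor([[-30.]])

a1 = torch.sigmoid(torch.matmul(X, W1) + b1)    # hidden layer, shape [4, 2]
a2 = torch.sigmoid(torch.matmul(a1, W2) + b2)   # output probabilities, shape [4, 1]
pred = (a2 >= 0.5).float()
print(pred.squeeze(1).tolist())                 # [0.0, 1.0, 1.0, 0.0]
```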
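III. Reflections
In this experiment I used classes to encapsulate the model training process and learned the steps involved in building a model. When fitting the moon dataset, I ran into some errors and poor fits. Following other classmates' approaches, I either reduced the noise or used a new moon dataset, and then obtained good fits. Some errors also came up while porting the code from Paddle to PyTorch, and there is still much for me to improve.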