本节主要基于Pytorch实现线性回归和非线性回归。帮助小白熟悉Pytorch的使用,深度理解神经网络的工作原理。为便于阅读与理解,本节将建立两个文件,分别为“Pytorch_Network_01.py”和“Pytorch_Train_01.py”。“Pytorch_Network_01.py”文件主要包含线性网络和非线性网络类,“Pytorch_Train_01.py”文件用于网络训练。
Pytorch_Network_01.py
'''
文件功能: 搭建全连接神经网络, 并进行初始化
线性网络: Linear_Net
非线性网络: Nonlinear_Net
'''
import torch
import torch.nn as nn
# 线性网络
class Linear_Net(nn.Module):
def __init__(self, input_dim, output_dim):
super(Linear_Net, self).__init__()
self.input_layer = nn.Linear(input_dim, 5)
self.output_layer = nn.Linear(5, output_dim)
def forward(self, x):
x1 = self.input_layer(x)
y = self.output_layer(x1)
return y
def init(self):
self.input_layer.weight.data = torch.randn(self.input_layer.weight.size())
self.output_layer.weight.data =