点击下方“JavaEdge”,选择“设为星标”
第一时间关注技术干货!
免责声明~
任何文章不要过度深思!
万事万物都经不起审视,因为世上没有同样的成长环境,也没有同样的认知水平,更「没有适用于所有人的解决方案」;
不要急着评判文章列出的观点,只需代入其中,适度审视一番自己即可,能「跳脱出来从外人的角度看看现在的自己处在什么样的阶段」才不为俗人。
怎么想、怎么做,全在乎自己「不断实践中寻找适合自己的大道」
0 学习目标
掌握使用PyTorch构建神经网络的基本流程和实现过程。PyTorch是一个强大的深度学习框架,其核心工具集中在torch.nn包中。这个包依赖于自动求导(autograd)机制来定义模型并计算梯度,省去了手动编写复杂数学公式的需求。
对于Java开发者来说,PyTorch的神经网络构建类似于设计一个复杂的Java类系统:你需要定义类、方法和字段,并通过循环和算法优化来处理数据和学习。
构建神经网络的流程
以下是构建神经网络的6个核心步骤,类似Java中开发一个数据处理系统:
定义一个包含可学习参数的神经网络(类似Java类和字段)。
遍历训练数据集(类似Java中的
for循环或迭代器)。处理输入数据使其流经神经网络(前向传播,类似数据流经对象方法)。
计算损失值(评估预测结果与真实结果的差距,类似Java中的误差计算方法)。
将网络参数的梯度进行反向传播(类似Java中根据误差调整字段值)。
按照特定规则更新网络权重(类似Java中的优化算法,如梯度下降)。
1 定义一个神经网络
PyTorch中的神经网络通过定义一个类来实现,类似于Java中的面向对象编程。以下是代码和详细解释:
代码示例
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineimport torchimport torch.nn as nnimport torch.nn.functional as F
# 定义神经网络类,继承自nn.Module(类似Java的继承)class Net(nn.Module): def __init__(self): super(Net, self).__init__() # 第一层卷积层:输入1个通道(例如灰度图像),输出6个通道,卷积核大小3x3 self.conv1 = nn.Conv2d(1, 6, 3) # 第二层卷积层:输入6个通道,输出16个通道,卷积核大小3x3 self.conv2 = nn.Conv2d(6, 16, 3) # 三层全连接层:展平后的输入维度为16 * 6 * 6,输出维度依次为120、84、10 self.fc1 = nn.Linear(16 * 6 * 6, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10)
# 前向传播方法,定义数据如何流经网络(类似Java方法) def forward(self, x): # 卷积 + 激活(ReLU) + 最大池化 x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # 再次卷积 + 激活 + 最大池化 x = F.max_pool2d(F.relu(self.conv2(x)), 2) # 展平数据,方便全连接层处理(类似Java中矩阵转置或展平) x = x.view(-1, self.num_flat_features(x)) # 全连接层 + 激活 x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) # 输出层 x = self.fc3(x) return x
# 计算展平后的特征数量(辅助方法) def num_flat_features(self, x): # 忽略批次维度(第0维),计算其余维度的乘积 size = x.size()[1:] # 类似Java中获取数组维度 num_features = 1 for s in size: num_features *= s return num_features
# 创建网络实例net = Net()print(net) # 打印网络结构详细讲解(面向Java开发者)
类定义:
Net类继承nn.Module,类似于Java中的class Net extends BaseModel。nn.Module提供了管理参数和网络层的基础功能。初始化(
__init__):相当于Java的构造函数,这里定义网络的层:nn.Conv2d是卷积层,类似图像处理中的滤波器,参数表示输入/输出通道数和卷积核大小。nn.Linear是全连接层,类似Java中的矩阵乘法,输入和输出维度由参数指定。
前向传播(
forward):定义数据如何通过网络处理,类似于Java中的方法调用链。F.relu是激活函数(类似Math.max(0, x)),F.max_pool2d是池化操作(降采样,减少数据维度)。参数管理:
net.parameters()返回所有可训练参数(权重和偏置),类似于Java中List<Matrix>,可以用来迭代更新。
测试网络
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line# 生成随机输入数据,形状为(1, 1, 32, 32):1个样本、1个通道、32x32像素input = torch.randn(1, 1, 32, 32)out = net(input) # 前向传播,输出形状为(1, 10),表示10个类别的预测print(out)
# 获取并查看参数params = list(net.parameters())print(len(params)) # 参数组数(每层权重和偏置)print(params[0].size()) # 第一个参数的形状(例如权重矩阵维度)
# 清零梯度并执行反向传播net.zero_grad()out.backward(torch.randn(1, 10)) # 模拟反向传播输入要求:PyTorch的
nn.Conv2d需要4D张量(nSamples, nChannels, Height, Width),不支持单样本输入。如果输入是3D张量,需要用input.unsqueeze(0)扩展为4D,类似Java中手动调整数组维度。
2 损失函数
损失函数用来评估网络输出(预测)和目标(真实值)之间的差距,类似于Java中定义一个误差计算方法。
代码示例
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line# 获取网络输出output = net(input)# 生成随机目标,形状为(10,),调整为(1, 10)与output匹配target = torch.randn(10)target = target.view(1, -1) # 调整形状
# 定义均方误差损失函数criterion = nn.MSELoss()# 计算损失值,一个标量表示预测与目标的差距loss = criterion(output, target)print(loss)
# 查看损失的计算图print(loss.grad_fn) # 输出损失的计算节点(如MSELoss)print(loss.grad_fn.next_functions[0][0]) # 线性层print(loss.grad_fn.next_functions[0][0].next_functions[0][0]) # ReLU详细讲解
损失函数:
nn.MSELoss计算均方误差((output - target)^2的平均值),类似于Java中的平方差计算。计算图:PyTorch自动构建计算图(类似Java的依赖树),
loss.grad_fn显示计算路径(如input -> conv2d -> relu -> ... -> MSELoss -> loss)。自动求导:当调用
loss.backward()时,PyTorch会对所有requires_grad=True的张量计算梯度,并累加到.grad属性,类似Java中追踪对象依赖并更新字段。
3 反向传播(Backpropagation)
反向传播是神经网络训练的核心,用于计算每个参数的梯度,进而优化模型。
代码示例
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(line# 清零梯度,防止累加net.zero_grad()
# 打印反向传播前的梯度print('conv1.bias.grad before backward')print(net.conv1.bias.grad) # 初始为0
# 执行反向传播loss.backward()
# 打印反向传播后的梯度print('conv1.bias.grad after backward')print(net.conv1.bias.grad) # 梯度变为非零值详细讲解
梯度清零:
net.zero_grad()清空所有参数的梯度,类似于Java中重置计数器或清空缓存,防止不同批次数据的梯度累加。反向传播:
loss.backward()从损失值开始,沿着计算图反向计算每个参数的梯度,存入.grad属性,类似Java中根据误差反向调整对象字段。重要性:确保每次训练迭代开始时梯度为0,避免累积误差影响训练。
4 更新网络参数
训练的最后一步是根据梯度更新参数,最常见的方法是随机梯度下降(SGD)。更新公式为:
传统方法
ounter(lineounter(lineounter(linelearning_rate = 0.01for f in net.parameters(): f.data.sub_(f.grad.data * learning_rate) # 手动更新参数PyTorch优化器(推荐)
ounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineounter(lineimport torch.optim as optim
# 创建SGD优化器,管理网络参数,学习率为0.01optimizer = optim.SGD(net.parameters(), lr=0.01)
# 训练步骤optimizer.zero_grad() # 清零梯度output = net(input) # 前向传播loss = criterion(output, target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数详细讲解
手动SGD:直接用公式更新参数,类似Java中循环遍历对象字段并应用算法。
优化器:
optim.SGD是PyTorch提供的工具,封装了SGD逻辑,类似Java中的优化算法库,简化了参数更新。步骤:
optimizer.step()自动根据梯度更新所有参数,比手动循环更高效。
5 总结
构建神经网络的典型流程
定义网络:创建一个包含可学习参数的类(类似Java类),定义层和前向传播逻辑。
遍历数据:使用数据集循环训练(类似Java的
for循环)。前向传播:处理输入数据得到预测(类似方法调用链)。
计算损失:用损失函数(如
nn.MSELoss)评估预测与目标差距。反向传播:通过
loss.backward()计算梯度(类似Java中反向调整字段)。更新参数:用优化器(如
optim.SGD)更新权重(类似Java中优化算法)。
关键技术点
损失函数:
nn.MSELoss计算均方误差,loss.backward()触发自动求导,更新requires_grad=True张量的.grad属性。反向传播:
net.zero_grad()清零梯度,loss.backward()计算梯度,确保每次迭代独立。参数更新:通过优化器(如SGD)实现,公式为
weight = weight - learning_rate * gradient。
面向Java开发者的类比
构建PyTorch神经网络就像设计一个Java系统:
网络类
Net类似于Java类,层(如nn.Conv2d、nn.Linear)是字段。前向传播和反向传播是方法调用,损失函数和优化器像是Java中的工具类。
PyTorch的自动求导省去了手动计算梯度的麻烦,类似Java框架(如Spring)自动管理依赖。
🚀 魔都架构师 | 全网30W+技术追随者
🔧 大厂分布式系统/数据中台实战专家
🏆 主导交易系统亿级流量调优 & 车联网平台架构
🧠 AIGC应用开发先行者 | 区块链落地实践者
🌍 以技术驱动创新,我们的征途是改变世界!
👉 实战干货:编程严选网
关注我,紧跟本系列专栏文章,咱们下篇再续!
扫码加入 DeepSeek 使用交流群:

写在最后
编程严选网:
http://www.javaedge.cn/专注分享软件开发全场景最佳实践,点击文末【阅读原文】即可直达~
279

被折叠的 条评论
为什么被折叠?



