nn.Module
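A neural network can be viewed as a complex nested structure built from layers. To keep network-building code clear and structured, frameworks such as PyTorch package a large set of building blocks for constructing networks, all of which inherit from torch.nn.Module.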
Neural networks consist of layers/modules that perform operations on data. The torch.nn namespace provides all the building blocks you need to build your own neural network. Every module in PyTorch subclasses nn.Module. A neural network is itself a module that consists of other modules (layers). This nested structure makes it easy to build and manage complex architectures.
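To make the nesting concrete, here is a minimal sketch. The classes `Block` and `Net` are hypothetical and exist only for illustration. Any `nn.Module` assigned as an attribute becomes a child module, so `named_modules()` and `parameters()` walk the whole tree recursively.

```python
import torch
from torch import nn

# Hypothetical block: itself an nn.Module built from two other modules
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.linear(x))

# Hypothetical network whose children are two Blocks and a plain layer
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = Block()
        self.block2 = Block()
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(self.block2(self.block1(x)))

net = Net()

# Direct children only
for name, child in net.named_children():
    print(name, type(child).__name__)  # block1 Block / block2 Block / head Linear

# Recursive walk over every submodule, using dotted names ('' is net itself)
print([name for name, _ in net.named_modules()])

# Parameters are also collected through the nesting: 2 * (16 + 4) + (8 + 2)
total = sum(p.numel() for p in net.parameters())
print(total)  # 50
```

Because registration happens automatically, the outer module never has to list its submodules' parameters by hand.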
Common imports:
import torch
from torch import nn
Get Device for Training
By specifying a device, we can speed up training on a hardware accelerator such as a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
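Choosing the device is only half the job: the model and its input tensors must also be moved there. A minimal sketch, using a throwaway `nn.Linear(3, 1)` as a stand-in model (an assumption for illustration, not the network defined below):

```python
import torch
from torch import nn

# Same device selection rule as above: GPU if available, else CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in model for illustration only
model = nn.Linear(3, 1)
# Module.to() moves parameters and buffers in place and returns the module itself
model = model.to(device)

# Tensor.to() returns a new tensor on the target device; the original is unchanged
x = torch.randn(5, 3).to(device)

# Inputs and parameters must live on the same device, or the call raises an error
y = model(x)
print(y.device, next(model.parameters()).device)
```

The asymmetry matters in practice: forgetting to reassign `x = x.to(device)` leaves the data on the CPU, while `model.to(device)` works even without reassignment.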
Define the Class
The usual way to define a neural network class:
- define the layers and operations in __init__()
- define the forward pass in forward() (i.e., how these layers are combined)
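A minimal sketch of this two-method pattern, using a hypothetical `TinyNet` that is separate from the network defined next. Note that the model is invoked as `model(x)` rather than `model.forward(x)`: `nn.Module.__call__` runs any registered hooks and then dispatches to `forward()`.

```python
import torch
from torch import nn

# Hypothetical minimal network illustrating the __init__/forward split
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Layers assigned as attributes are registered as submodules automatically
        self.fc = nn.Linear(10, 3)

    def forward(self, x):
        # forward() only describes how the input flows through the layers
        return torch.relu(self.fc(x))

model = TinyNet()
x = torch.randn(4, 10)  # batch of 4 samples, 10 features each

# Call the module itself, not model.forward(x), so hooks are honored
out = model(x)
print(out.shape)  # torch.Size([4, 3])
```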
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512,