PaddlePaddle 2.0和PyTorch风格还是非常像的。使用PaddlePaddle可以直接调用百度AI Studio里的一些资源(包括GPU、预训练权重之类的),而且说明文档、社区都是中文的,比较友好;而PyTorch在Github有更多的代码与资源,两者配合使用是比较香的。下面整理了一些PaddlePaddle以及PyTorch中对应的函数。当然,最好的使用方法是知道对应关系之后,
去PyTorch、PaddlePaddle官网上的数据手册查看具体说明
1 常用的包
PyTorch | PaddlePaddle | 说明 |
---|
torch.nn | paddle.nn | 包括了神经网络相关的大部分函数 |
nn.Module | nn.Layer | 搭建网络时集成的父类,包含了初始化等基本功能 |
torch.optim | paddle.optimizer | 训练优化器 |
torchvision.transforms | paddle.vision.transforms | 数据预处理、图片处理 |
torchvision.datasets | paddle.vision.datasets | 数据集的加载与处理 |
| | |
| | |
| | |
| | |
2 网络结构
这一部分函数的输入参数基本是一致的,有不一致的地方会特别说明
PyTorch | PaddlePaddle | 说明 |
---|
nn.Conv2d | nn.Conv2D | 2维卷积层 |
nn.BatchNorm2d | nn.BatchNorm2D | Batch Normalization 归一化 |
nn.ReLU | nn.ReLU | ReLU激活函数 |
nn.MaxPool2d | nn.MaxPool2D | 二维最大池化层 |
nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool2D | 自适应二维平均池化(只用给定输出形状即可) |
nn.Linear | nn.Linear | 全连接层 |
nn.Sequential | nn.Sequential | 顺序容器,用来添加layers |
torch.flatten | paddle.flatten | 展平处理 |
torch.softmax | paddle.softmax | softmax层 |
| | |
| | |
| | |
| | |
| | |
3 数据加载与处理
PyTorch | PaddlePaddle | 说明 |
---|
transforms.Compose | transforms.Compose | 图片处理打包 |
transforms.RandomResizedCrop | transforms.RandomResizedCrop | 随机裁剪 |
transforms.RandomHorizontalFlip | transforms.RandomHorizontalFlip | 随机水平翻转 |
transforms.ToTensor | transforms.ToTensor | 转化为tensor格式 |
transforms.Normalize | transforms.Normalize | 数据标准化 |
datasets.ImageFolder | datasets.DatasetFolder | 指定数据集文件夹 |
torch.utils.data.DataLoader | paddle.io.DataLoader | 加载数据集 |
| | |
4 模型训练
这里括号表示为用户自己定义的变量名
PyTorch | PaddlePaddle | 说明 |
---|
(net).train | (net).train | 训练模式 |
(loss).backward | (loss).backward | 反向传递误差 |
optim.Adam | optim.Adam | Adam优化器,注意paddlepaddle中的参数分别为parameters和learning _rate,与PyTorch中是不同的 |
(optimizer).no_grad | (optimizer).zero_grad | 梯度清零 |
torch.save | paddle.jit.save | 说实话,这两个还是有点区别的,使用请看官方文档 |
(net).eval | (net).eval | 预测模式 |
| | |
| | |
| | |
5 模型预测
PyTorch | PaddlePaddle | 说明 |
---|
torch.unsqueeze | paddle.unsqueeze | 增加数据维度 |
torch.no_grad | paddle.no_grad | 不计算梯度 |
| | |
6 其它
PyTorch | PaddlePaddle | 说明 |
---|
torch.device | paddle.set_device | 指定设备 |
| | |
| | |
| | |