0. Mapping Tables and Differences
PyTorch-PaddlePaddle API Mapping Table
1. Libraries
| Name | PyTorch | PaddlePaddle |
|---|---|---|
| Dataset | from torch.utils.data import Dataset | from paddle.io import Dataset |
| DataLoader | from torch.utils.data import DataLoader | from paddle.io import DataLoader |
| transforms | from torchvision import transforms | from paddle.vision import transforms |
| Get device | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | device = paddle.device.get_device() (returns a string such as "gpu:0" or "cpu") |
| Set device | var.to(device) (per tensor or model) | paddle.device.set_device(device) (global; setting it once is enough) |
2. API
| Name | PyTorch | PaddlePaddle |
|---|---|---|
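| Base layer class | nn.Module | nn.Layer |
| Layers with a dimension suffix | nn.Conv2d, nn.BatchNorm2d, nn.MaxPool2d, ... | nn.Conv2D, nn.BatchNorm2D, nn.MaxPool2D, ... (Paddle uses an uppercase D) |
| Concatenate | torch.cat(tensors, dim=...) | paddle.concat(x, axis=...) |
| Flatten (keep batch axis) | x.view(x.size(0), -1), torch.flatten(x, 1), or nn.Flatten() | x.reshape([x.shape[0], -1]), paddle.flatten(x, start_axis=1), or nn.Flatten() |
| Optimizers | torch.optim | paddle.optimizer |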