很多同学说每次使用PyTorch
时都需要导入很多模块,非常混乱,今天我就将PyTorch
常用的模块做一个总结梳理。
首先要说明的是PyTorch
这是torch
的Python
版本,所以导入的是torch
而不是Pytorch
:
import torch
1 运行基础
torch.tensor
:基础数据结构torch.autograd
:自动微分模块
2 torch.utils
支持神经网络相关的数据预处理。
-
数据导入与处理
utils.data
utils.datasets
-
utils.tensorboard
:训练结果的可视化 -
utils.model_zoo
:预训练模型
3 torch.nn
构建神经网络结构的基本元素。
nn.Module
:神经网络的各种结构“层”nn.functional
:神经网络的损失函数与激活函数
4 torch.optim
神经网络的算法优化模块,封装着各类优化器。
5 神经网络的运算性能
torch.torchelastic
:分布式训练torch.cuda
:在GPU
上训练
6 torch.JIT
生产环境中部署的模块。
与torch
并列的库
下面介绍几个经常导入,但其实与torch
是一个级别的库。
7 torchvision
计算机视觉
因为7之前的内容都是torch.xx
,所以torchvision
是与torch
等级的库。
import torchvision
torchvision.datasets
:CV常用数据torchvision.models
:CV常用模型torchvision.transforms
:图像数据的预处理工具
8 torchtext
自然语言处理
torchtext.data
:文字数据的数据预处理torchtext.datasets
:NLP领域的常用数据集
9 torchaudio
语音处理
torchaudio.datasets
:语音领域的常用数据集torchaudio.transforms
:语音领域的预处理工具torchaudio.models
:语音领域的常用模型torchaudio.functional
:语音领域的常用函数