学习目标:
`
- 学习Pytorch 主要模块
参考内容:https://linklearner.com/datawhalehomepage/index.html#/learn/detail/92
1.基本配置:
`对GPU的设置有两种常见的方式:
```python
# 方案一:使用os.environ,这种情况如果使用GPU不需要设置
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
2.数据读入:
PyTorch数据读入是通过Dataset+DataLoader的方式完成的,Dataset定义好数据的格式和数据变换形式,DataLoader用iterative的方式不断读入批次数据。
例如:
- init: 用于向类中传入外部参数,同时定义样本集
- getitem: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据
- len: 用于返回数据集的样本数
3.模型构建:
4.损失函数:
一个模型想要达到很好的效果需要学习,也就是我们常说的训练。一个好的训练离不开优质的负反馈,这里的损失函数就是模型的负反馈。
5.优化器设置
优化器是根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值,使得模型输出更加接近真实标签。