PyTorch Cookbook——常用代码整理

本文参考自 PyTorch Cookbook(常用代码段整理合集)

训练代码示例

先放个模型训练的整个 .py 代码范例:Pytorch 训练模型代码范例

检查版本

torch.__version__               # PyTorch version
torch.version.cuda              # Corresponding CUDA version
torch.backends.cudnn.version()  # Corresponding cuDNN version
torch.cuda.is_available() 		# 判断是否有CUDA支持

关于什么是 CUDA, cuDNN,以及 NVCC, CUDA Toolkit 又是什么,可以参考我的另一篇文章:PyG 安装以及关于 CUDA 的扩展

cuDNN benchmark 模式

torch.backends.cudnn.benchmark = True 时,大大提升卷积神经网络的运行速度。

卷积前向传播的实现有许多种实现方式。benchmark 模式下,程序会在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。
适用场景:网络结构固定,输入形状不变。具体来说,输入的 batch size,宽和高,输入通道的个数;卷积层本身的参数,包括卷积核大小,stride,padding ,输出通道的个数,这些都要固定,才能实现较好的加速效果。否则可能得不偿失。——torch.backends.cudnn.benchmark ?!

具体使用:在代码开头设置即可

torch.backends.cudnn.benchmark = True

torch.no_grad() 和 model.eval()

with torch.no_grad() 是一个上下文管理器 (context manager),它会创建一个环境,在此之内的张量运算均不会计算梯度。
model.eval() 会改变模型中某些层的前向传播行为,如 BatchNorm layer, Dropout layer.

它们实际上在做完全不一样的事情。在做模型推理时,一定要设置 model.eval();最好也用 with torch.no_grad,节省计算资源。

model.eval()
with torch.no_grad():
	output = model(input)
	...

关于自动求导和 Pytorch 的动态计算图,详见 Pytorch 中的自动求导与(动态)计算图


结果可复现

参考:Pytorch: REPRODUCIBILITY

设置随机种子:

np.random.seed(0)
torch.manual_seed(0) # seed the RNG for all devices (both CPU and CUDA)

Disable benchmark 模式:卷积操作只会采用一种固定的算法:

torch.backends.cudnn.benchmark = False # force cuDNN to deterministically select an convolution algorithm

选定的卷积算法可能本身是 non deterministic。规定 Pytorch 只能使用 deterministic 的卷积算法:

torch.backends.cudnn.deterministic=True 

规定所有运算只能使用 deterministic 算法(如果某种运算没有相应的 deterministic 算法实现,就会报错):

torch.use_deterministic_algorithms(True)

设置 Dataloader:

DataLoader will reseed workers following Randomness in multi-process data loading algorithm. Use worker_init_fn() and generator to preserve reproducibility:

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)

我在另一篇文章 可复现的 PyTorch 中有更详细的讨论。

张量命名

为张量的每一个维度取一个名字,提高可读性,防止出错。也可以使用维度的名字来做索引或其他操作。

images = torch.randn(32, 3, 56, 56, names=('B', 'C', 'H', 'W'))
images.sum('C')
images.select('C', index=0)

打乱顺序

tensor = tensor[torch.randperm(tensor.size(0))]  # 打乱第一个维度

复制张量

# Operation                 |  New/Shared memory | Still in computation graph |
tensor.clone()            # |        New         |          Yes               |
tensor.detach()           # |      Shared        |          No                |
tensor.detach.clone()()   # |        New         |          No                |

拼接张量

t1 = torch.randn(10,5)
t2 = torch.randn(10,5)
print("Concat list of tensors: ", torch.cat((t1, t2), dim=0).shape)
print("Stack list of tensors: ", torch.stack((t1, t2), dim=0).shape)
# Concat list of tensors:  torch.Size([20, 5])
# Stack list of tensors:  torch.Size([2, 10, 5])

torch.cat 沿着给定的维度拼接,而 torch.stack 会新增一维。

矩阵乘法

下面几个都可以做矩阵乘法,你知道它们的区别吗?

  • torch.mm - performs a matrix multiplication without broadcasting - (2D tensor) by (2D tensor)
  • torch.mul - performs a elementwise multiplication with broadcasting - (Tensor) by (Tensor or Number)
  • torch.matmul - matrix product with broadcasting - (Tensor) by (Tensor) with different behaviors depending on the tensor shapes (dot product, matrix product, batched matrix products).
  • torch.bmm - batched matrix products without broadcasting.
t3 = torch.randn(3, 5, 8)
t4 = torch.randn(3, 8, 9)
torch.bmm(t3, t4).shape
# torch.Size([3, 5, 9])

cf. What’s the difference between torch.mm, torch.matmul and torch.mul?

学习率规划器

Pytorch 常见的 scheduler 可以看这篇文章:Pytorch 学习率规划器

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值