FairScale is a PyTorch extension library for high-performance and large-scale training on one or more machines/nodes.
Released by Facebook.
FairScale supports:
- pipeline parallelism (fairscale.nn.Pipe)
- optimizer state sharding (fairscale.optim.oss)
【Example】
Run a four-layer model on 2 GPUs. The first two layers run on cuda:0 and the last two on cuda:1.
from torch import nn
import fairscale

model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=6, kernel_size=(5, 5), stride=1, padding=0),
    nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=0),
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=1, padding=0),
    nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=0),
)
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
First published 2021-09-05 10:28:48