本文承接上一篇:上一篇使用原生 PyTorch 实现了多 GPU 训练和混合精度,本篇对照上一篇的代码,改用 Fabric 来实现相同的功能。关于 Fabric,我会在后续的博客中继续讲解——既是讲解,也是自己的学习过程。使用 Fabric 可以减少代码量,同时提升开发速度。大家用得更多的可能是 Hugging Face 的 accelerate,这个我也了解过,它需要事先在外部用 yaml 做配置,同样很方便。不过我日常的工作主要是单机多卡,也更喜欢灵活的写法,因此 Fabric 就成了一个不错的选择:它可以很方便地配置分布式训练策略(DDP、FSDP、DeepSpeed)、硬件选择、是否使用混合精度训练等。
相比上一篇,模型稍作修改(加入了 BatchNorm 层),目的只是观察多卡训练对 BN 的影响。直接上代码:
1、假数据处理
import torch
from torch import nn
from lightning import Fabric
from torchinfo import summary
def train(num_epochs,model,optimizer,data,target,fabric):
    """Run a toy full-batch training loop through Fabric, then save a checkpoint.

    Args:
        num_epochs: Number of optimisation steps; every step reuses the same
            ``data``/``target`` batch.
        model: Module to train (in the article: the object returned by
            ``fabric.setup``).
        optimizer: Optimizer matching ``model``.
        data: Input tensor, moved to the Fabric device before training.
        target: Regression target tensor, moved to the Fabric device as well.
        fabric: Object providing ``to_device``, ``backward``, ``all_gather``,
            ``save``, ``device``, ``local_rank`` and ``global_rank``.

    Side effects:
        Prints the per-process loss and the gathered losses every epoch and
        writes ``checkpoint.ckpt`` through ``fabric.save`` once training ends.
    """
    model.train()
    # fabric.to_device(x) is used here in place of x.to(fabric.device).
    data = fabric.to_device(data)
    target = fabric.to_device(target)

    # torch.distributed.get_rank() raises when no process group has been
    # initialised (e.g. a single-device run), so fall back to Fabric's rank.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch_rank = torch.distributed.get_rank()
    else:
        torch_rank = fabric.global_rank
    # In the article's single-node run these three values identify the same process.
    print("fabric.device and local_rank and torch local rank:",fabric.device,fabric.local_rank,torch_rank)

    criterion = torch.nn.MSELoss()  # built once instead of once per epoch
    completed_epochs = 0  # stays 0 when num_epochs == 0, so the checkpoint is still valid
    for epoch in range(num_epochs):
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        fabric.backward(loss)
        optimizer.step()
        # Every process prints its own loss.
        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | train loss:{loss}")
        # Collect the loss of every process: one entry per GPU.
        all_loss = fabric.all_gather(loss)
        print(all_loss)
        completed_epochs = epoch + 1

    # Save the model, the optimizer and the number of finished epochs.
    state = {"model": model, "optimizer": optimizer, "iter": completed_epochs}
    fabric.save("checkpoint.ckpt", state)
class SimpleModel(nn.Module):
    """Minimal conv -> batch-norm -> global-average-pool -> linear network.

    Takes 3-channel image batches and produces one value per sample. The
    batch-norm layer is what the surrounding script later converts to
    SyncBatchNorm.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order determine the state_dict keys.
        self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3, stride=1)
        self.bn = nn.BatchNorm2d(num_features=5)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.flat = nn.Flatten()
        self.fc = nn.Linear(in_features=5, out_features=1)

    def forward(self, x):
        """Apply the five layers in sequence and return the final activations."""
        pipeline = (self.conv, self.bn, self.avg_pool, self.flat, self.fc)
        for layer in pipeline:
            x = layer(x)
        return x
if __name__ == "__main__":
    # Two GPUs, DDP strategy, fp16 automatic mixed precision.
    fabric = Fabric(
        accelerator="cuda",
        devices=[0, 1],
        strategy="ddp",
        precision='16-mixed',
    )
    fabric.launch()
    fabric.seed_everything()  # called without an explicit seed

    # Build the model; fabric.print only prints on GPU 0.
    model = SimpleModel()
    fabric.print("before setup model,state dict:")
    # Optional: fabric.print(summary(model, input_size=(1, 3, 8, 8)))
    fabric.print(model.state_dict().keys())
    fabric.print("*****************************************************************")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # With more than one process, swap BatchNorm for SyncBatchNorm before setup.
    if fabric.world_size > 1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    fabric.print("after convert bn to sync bn,state dict:")
    # Optional: fabric.print(summary(model, input_size=(1, 3, 8, 8)))
    # Plain print: every process reports where its weights currently live.
    print(f"after convert bn to sync bn device:{fabric.device} conv.weight.device:{model.conv.weight.device}")
    fabric.print(model.state_dict().keys())
    fabric.print("*****************************************************************")

    model, optimizer = fabric.setup(model, optimizer)
    print(f"after setup device:{fabric.device} conv.weight.device:{model.conv.weight.device}")
    fabric.print("after setup model,model state dict:")
    # Optional: fabric.print(summary(model, input_size=(1, 3, 8, 8)))
    fabric.print(model.state_dict().keys())

    # Fake data. With a real DataLoader, this is where everything except the
    # torch.utils.data.DistributedSampler would be built.
    data = torch.rand(5, 3, 8, 8)
    target = torch.rand(5, 1)

    # Start training.
    num_epochs = 100
    train(num_epochs, model, optimizer, data, target, fabric)
输出结果:
Using 16-bit Automatic Mixed Precision (AMP)
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
/home/tl/anaconda3/envs/ptch/lib/python3.10/site-packages/lightning/fabric/utilities/seed.py:40: No seed found, seed set to 3183422672
[rank: 0] Seed set to 3183422672
before setup model,state dict:
odict_keys(['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked', 'fc.weight', 'fc.bias'])
*****************************************************************
after convert bn to sync bn,state dict:
after convert bn to sync bn device:cuda:0 conv.weight.device:cpu
odict_keys(['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked', 'fc.weight', 'fc.bias'])
*****************************************************************
[rank: 1] Seed set to 1590652679
after convert bn to sync bn device:cuda:1 conv.weight.device:cpu
after setup device:cuda:1 conv.weight.device:cuda:1
after setup device:cuda:0 conv.weight.device:cuda:0
after setup model,model state dict:
odict_keys(['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked', 'fc.weight', 'fc.bias'])
fabric.device and local_rank and torch local rank: cuda:1 1 1
fabric.device and local_rank and torch local rank: cuda:0 0 0
Epoch: 0001/0100 | train loss:0.5391270518302917
Epoch: 0001/0100 | train loss:0.4002908766269684
tensor([0.5391, 0.4003], device='cuda:0')
tensor([0.5391, 0.4003], device='cuda:1')
Epoch: 0002/0100 | train loss:0.5391270518302917
Epoch: 0002/0100 | train loss:0.4002908766269684
tensor([0.5391, 0.4003], device='cuda:0')
tensor([0.5391, 0.4003], device='cuda:1')
Epoch: 0003/0100 | train loss:0.3809531629085541
Epoch: 0003/0100 | train loss:0.5164263844490051
tensor([0.5164, 0.3810], device='cuda:1')
tensor([0.5164, 0.3810], device='cuda:0')
Epoch: 0004/0100 | train loss:0.3625626266002655
Epoch: 0004/0100 | train loss:0.49487170577049255
tensor([0.4949, 0.3626], device='cuda:0')
tensor([0.4949, 0.3626], device='cuda:1')
Epoch: 0005/0100 | train loss:0.34520527720451355
Epoch: 0005/0100 | train loss:0.47438523173332214
tensor([0.4744, 0.3452], device='cuda:1')
tensor([0.4744, 0.3452], device='cuda:0')
Epoch: 0006/0100 | train loss:0.32876724004745483
Epoch: 0006/0100 | train loss:0.45497187972068787
tensor([0.4550, 0.3288], device='cuda:1')
tensor([0.4550, 0.3288], device='cuda:0')
Epoch: 0007/0100 | train loss:0.4365047514438629
Epoch: 0007/0100 | train loss:0.31321704387664795
tensor([0.4365, 0.3132], device='cuda:0')
tensor([0.4365, 0.3132], device='cuda:1')
Epoch: 0008/0100 | train loss:0.41904139518737793
Epoch: 0008/0100 | train loss:0.2985176146030426
tensor([0.4190, 0.2985], device='cuda:0')
tensor([0.4190, 0.2985], device='cuda:1')
Epoch: 0009/0100 | train loss:0.4022897183895111
Epoch: 0009/0100 | train loss:0.28452268242836
tensor([0.4023, 0.2845], device='cuda:0')
tensor([0.4023, 0.2845], device='cuda:1')
Epoch: 0010/0100 | train loss:0.38661184906959534
Epoch: 0010/0100 | train loss:0.2712869644165039
tensor([0.3866, 0.2713], device='cuda:0')
tensor([0.3866, 0.2713], device='cuda:1')
Epoch: 0011/0100 | train loss:0.37144994735717773
Epoch: 0011/0100 | train loss:0.2587887942790985
tensor([0.3714, 0.2588], device='cuda:0')
tensor([0.3714, 0.2588], device='cuda:1')
Epoch: 0012/0100 | train loss:0.3572254776954651
Epoch: 0012/0100 | train loss:0.24688617885112762
tensor([0.3572, 0.2469], device='cuda:0')
tensor([0.3572, 0.2469], device='cuda:1')
Epoch: 0013/0100 | train loss:0.34366878867149353
Epoch: 0013/0100 | train loss:0.23560750484466553
tensor([0.3437, 0.2356], device='cuda:0')
tensor([0.3437, 0.2356], device='cuda:1')
Epoch: 0014/0100 | train loss:0.33070918917655945
Epoch: 0014/0100 | train loss:0.22490985691547394
tensor([0.3307, 0.2249], device='cuda:0')
tensor([0.3307, 0.2249], device='cuda:1')
Epoch: 0015/0100 | train loss:0.318371444940567
Epoch: 0015/0100 | train loss:0.21479550004005432
tensor([0.3184, 0.2148], device='cuda:0')
tensor([0.3184, 0.2148], device='cuda:1')
Epoch: 0016/0100 | train loss:0.30663591623306274
Epoch: 0016/0100 | train loss:0.20525796711444855
tensor([0.3066, 0.2053], device='cuda:0')
tensor([0.3066, 0.2053], device='cuda:1')
Epoch: 0017/0100 | train loss:0.2955937087535858
Epoch: 0017/0100 | train loss:0.19613352417945862
tensor([0.2956, 0.1961], device='cuda:0')
tensor([0.2956, 0.1961], device='cuda:1')
Epoch: 0018/0100 | train loss:0.2850213646888733
Epoch: 0018/0100 | train loss:0.18744778633117676
tensor([0.2850, 0.1874], device='cuda:0')
tensor([0.2850, 0.1874], device='cuda:1')
Epoch: 0019/0100 | train loss:0.27490052580833435
Epoch: 0019/0100 | train loss:0.17930081486701965
tensor([0.2749, 0.1793], device='cuda:0')
tensor([0.2749, 0.1793], device='cuda:1')
Epoch: 0020/0100 | train loss:0.265290230512619
Epoch: 0020/0100 | train loss:0.17152751982212067
tensor([0.2653, 0.1715], device='cuda:0')
tensor([0.2653, 0.1715], device='cuda:1')
Epoch: 0021/0100 | train loss:0.25619110465049744
Epoch: 0021/0100 | train loss:0.16420160233974457
tensor([0.2562, 0.1642], device='cuda:0')
tensor([0.2562, 0.1642], device='cuda:1')
Epoch: 0022/0100 | train loss:0.24748849868774414
Epoch: 0022/0100 | train loss:0.15718798339366913
tensor([0.2475, 0.1572], device='cuda:0')
tensor([0.2475, 0.1572], device='cuda:1')
Epoch: 0023/0100 | train loss:0.23922590911388397
Epoch: 0023/0100 | train loss:0.15056990087032318
tensor([0.2392, 0.1506], device='cuda:0')
tensor([0.2392, 0.1506], device='cuda:1')
Epoch: 0024/0100 | train loss:0.2313191443681717
Epoch: 0024/0100 | train loss:0.14431701600551605
tensor([0.2313, 0.1443], device='cuda:0')
tensor([0.2313, 0.1443], device='cuda:1')
Epoch: 0025/0100 | train loss:0.22383789718151093
Epoch: 0025/0100 | train loss:0.13829165697097778
tensor([0.2238, 0.1383], device='cuda:0')
tensor([0.2238, 0.1383], device='cuda:1')
Epoch: 0026/0100 | train loss:0.2166999876499176
Epoch: 0026/0100 | train loss:0.13270090520381927
tensor([0.2167, 0.1327], device='cuda:0')
tensor([0.2167, 0.1327], device='cuda:1')
Epoch: 0027/0100 | train loss:0.12735657393932343
Epoch: 0027/0100 | train loss:0.2099115401506424
tensor([0.2099, 0.1274], device='cuda:1')
tensor([0.2099, 0.1274], device='cuda:0')
Epoch: 0028/0100 | train loss:0.2034330815076828
Epoch: 0028/0100 | train loss:0.12219982594251633
tensor([0.2034, 0.1222], device='cuda:0')
tensor([0.2034, 0.1222], device='cuda:1')
Epoch: 0029/0100 | train loss:0.19724245369434357
Epoch: 0029/0100 | train loss:0.11739777773618698
tensor([0.1972, 0.1174], device='cuda:0')
tensor([0.1972, 0.1174], device='cuda:1')
Epoch: 0030/0100 | train loss:0.1913725584745407
Epoch: 0030/0100 | train loss:0.11280806362628937
tensor([0.1914, 0.1128], device='cuda:0')
tensor([0.1914, 0.1128], device='cuda:1')
Epoch: 0031/0100 | train loss:0.1856645792722702
Epoch: 0031/0100 | train loss:0.10841526836156845
tensor([0.1857, 0.1084], device='cuda:0')
tensor([0.1857, 0.1084], device='cuda:1')
Epoch: 0032/0100 | train loss:0.18032146990299225
Epoch: 0032/0100 | train loss:0.10436604171991348
tensor([0.1803, 0.1044], device='cuda:0')
tensor([0.1803, 0.1044], device='cuda:1')
Epoch: 0033/0100 | train loss:0.17524836957454681
Epoch: 0033/0100 | train loss:0.10045601427555084
tensor([0.1752, 0.1005], device='cuda:0')
tensor([0.1752, 0.1005], device='cuda:1')
Epoch: 0034/0100 | train loss:0.1704605370759964
Epoch: 0034/0100 | train loss:0.0966917872428894
tensor([0.1705, 0.0967], device='cuda:0')
tensor([0.1705, 0.0967], device='cuda:1')
Epoch: 0035/0100 | train loss:0.1658073514699936
Epoch: 0035/0100 | train loss:0.09323866665363312
tensor([0.1658, 0.0932], device='cuda:0')
tensor([0.1658, 0.0932], device='cuda:1')
Epoch: 0036/0100 | train loss:0.16137376427650452
Epoch: 0036/0100 | train loss:0.08982827514410019
tensor([0.1614, 0.0898], device='cuda:0')
tensor([0.1614, 0.0898], device='cuda:1')
Epoch: 0037/0100 | train loss:0.15720796585083008
Epoch: 0037/0100 | train loss:0.0867210254073143
tensor([0.1572, 0.0867], device='cuda:0')
tensor([0.1572, 0.0867], device='cuda:1')
Epoch: 0038/0100 | train loss:0.15312625467777252
Epoch: 0038/0100 | train loss:0.08372923731803894
tensor([0.1531, 0.0837], device='cuda:0')
tensor([0.1531, 0.0837], device='cuda:1')
Epoch: 0039/0100 | train loss:0.14925920963287354
Epoch: 0039/0100 | train loss:0.0807720348238945
tensor([0.1493, 0.0808], device='cuda:0')
tensor([0.1493, 0.0808], device='cuda:1')
Epoch: 0040/0100 | train loss:0.14571939408779144
Epoch: 0040/0100 | train loss:0.07814785093069077
tensor([0.1457, 0.0781], device='cuda:0')
tensor([0.1457, 0.0781], device='cuda:1')
Epoch: 0041/0100 | train loss:0.1421670764684677
Epoch: 0041/0100 | train loss:0.07556602358818054
tensor([0.1422, 0.0756], device='cuda:0')
tensor([0.1422, 0.0756], device='cuda:1')
Epoch: 0042/0100 | train loss:0.13886897265911102
Epoch: 0042/0100 | train loss:0.07304538041353226
tensor([0.1389, 0.0730], device='cuda:0')
tensor([0.1389, 0.0730], device='cuda:1')
Epoch: 0043/0100 | train loss:0.13570688664913177
Epoch: 0043/0100 | train loss:0.07073201984167099
tensor([0.1357, 0.0707], device='cuda:0')
tensor([0.1357, 0.0707], device='cuda:1')
Epoch: 0044/0100 | train loss:0.13255445659160614
Epoch: 0044/0100 | train loss:0.06854959577322006
tensor([0.1326, 0.0685], device='cuda:0')
tensor([0.1326, 0.0685], device='cuda:1')
Epoch: 0045/0100 | train loss:0.12969191372394562
Epoch: 0045/0100 | train loss:0.06643456220626831
tensor([0.1297, 0.0664], device='cuda:0')
tensor([0.1297, 0.0664], device='cuda:1')
Epoch: 0046/0100 | train loss:0.12693797051906586
Epoch: 0046/0100 | train loss:0.06441470235586166
tensor([0.1269, 0.0644], device='cuda:0')
tensor([0.1269, 0.0644], device='cuda:1')
Epoch: 0047/0100 | train loss:0.12435060739517212
Epoch: 0047/0100 | train loss:0.06256702542304993
tensor([0.1244, 0.0626], device='cuda:0')
tensor([0.1244, 0.0626], device='cuda:1')
Epoch: 0048/0100 | train loss:0.12184498459100723
Epoch: 0048/0100 | train loss:0.06076086685061455
tensor([0.1218, 0.0608], device='cuda:0')
tensor([0.1218, 0.0608], device='cuda:1')
Epoch: 0049/0100 | train loss:0.11948590725660324
Epoch: 0049/0100 | train loss:0.05909023433923721
tensor([0.1195, 0.0591], device='cuda:0')
tensor([0.1195, 0.0591], device='cuda:1')
Epoch: 0050/0100 | train loss:0.11719142645597458
Epoch: 0050/0100 | train loss:0.05748440697789192
tensor([0.1172, 0.0575], device='cuda:0')
tensor([0.1172, 0.0575], device='cuda:1')
Epoch: 0051/0100 | train loss:0.11490301042795181
Epoch: 0051/0100 | train loss:0.05596492439508438
tensor([0.1149, 0.0560], device='cuda:0')
tensor([0.1149, 0.0560], device='cuda:1')
Epoch: 0052/0100 | train loss:0.11284526437520981
Epoch: 0052/0100 | train loss:0.05452785640954971
tensor([0.1128, 0.0545], device='cuda:0')
tensor([0.1128, 0.0545], device='cuda:1')
Epoch: 0053/0100 | train loss:0.11080770939588547
Epoch: 0053/0100 | train loss:0.053089436143636703
tensor([0.1108, 0.0531], device='cuda:0')
tensor([0.1108, 0.0531], device='cuda:1')
Epoch: 0054/0100 | train loss:0.1088673397898674
Epoch: 0054/0100 | train loss:0.05177140235900879
tensor([0.1089, 0.0518], device='cuda:0')
tensor([0.1089, 0.0518], device='cuda:1')
Epoch: 0055/0100 | train loss:0.10703599452972412
Epoch: 0055/0100 | train loss:0.05052466318011284
tensor([0.1070, 0.0505], device='cuda:0')
tensor([0.1070, 0.0505], device='cuda:1')
Epoch: 0056/0100 | train loss:0.10530979931354523
Epoch: 0056/0100 | train loss:0.049302320927381516
tensor([0.1053, 0.0493], device='cuda:0')
tensor([0.1053, 0.0493], device='cuda:1')
Epoch: 0057/0100 | train loss:0.10361965000629425
Epoch: 0057/0100 | train loss:0.048224009573459625
tensor([0.1036, 0.0482], device='cuda:0')
tensor([0.1036, 0.0482], device='cuda:1')
Epoch: 0058/0100 | train loss:0.10195320099592209
Epoch: 0058/0100 | train loss:0.04709456115961075
tensor([0.1020, 0.0471], device='cuda:0')
tensor([0.1020, 0.0471], device='cuda:1')
Epoch: 0059/0100 | train loss:0.10047540813684464
Epoch: 0059/0100 | train loss:0.04614344984292984
tensor([0.1005, 0.0461], device='cuda:0')
tensor([0.1005, 0.0461], device='cuda:1')
Epoch: 0060/0100 | train loss:0.09898962825536728
Epoch: 0060/0100 | train loss:0.045158226042985916
tensor([0.0990, 0.0452], device='cuda:0')
tensor([0.0990, 0.0452], device='cuda:1')
Epoch: 0061/0100 | train loss:0.097608782351017
Epoch: 0061/0100 | train loss:0.044237129390239716
tensor([0.0976, 0.0442], device='cuda:0')
tensor([0.0976, 0.0442], device='cuda:1')
Epoch: 0062/0100 | train loss:0.09622994810342789
Epoch: 0062/0100 | train loss:0.043375153094530106
tensor([0.0962, 0.0434], device='cuda:0')
tensor([0.0962, 0.0434], device='cuda:1')
Epoch: 0063/0100 | train loss:0.09495609253644943
Epoch: 0063/0100 | train loss:0.04254027456045151
tensor([0.0950, 0.0425], device='cuda:0')
tensor([0.0950, 0.0425], device='cuda:1')
Epoch: 0064/0100 | train loss:0.04172029718756676
Epoch: 0064/0100 | train loss:0.09371034801006317
tensor([0.0937, 0.0417], device='cuda:1')
tensor([0.0937, 0.0417], device='cuda:0')
Epoch: 0065/0100 | train loss:0.04094156622886658
Epoch: 0065/0100 | train loss:0.09246573597192764
tensor([0.0925, 0.0409], device='cuda:0')
tensor([0.0925, 0.0409], device='cuda:1')
Epoch: 0066/0100 | train loss:0.09130342304706573
Epoch: 0066/0100 | train loss:0.040253669023513794
tensor([0.0913, 0.0403], device='cuda:0')
tensor([0.0913, 0.0403], device='cuda:1')
Epoch: 0067/0100 | train loss:0.09026143699884415
Epoch: 0067/0100 | train loss:0.03958689793944359
tensor([0.0903, 0.0396], device='cuda:0')
tensor([0.0903, 0.0396], device='cuda:1')
Epoch: 0068/0100 | train loss:0.08916200697422028
Epoch: 0068/0100 | train loss:0.03885350748896599
tensor([0.0892, 0.0389], device='cuda:0')
tensor([0.0892, 0.0389], device='cuda:1')
Epoch: 0069/0100 | train loss:0.08816101402044296
Epoch: 0069/0100 | train loss:0.03830384090542793
tensor([0.0882, 0.0383], device='cuda:0')
tensor([0.0882, 0.0383], device='cuda:1')
Epoch: 0070/0100 | train loss:0.08718284964561462
Epoch: 0070/0100 | train loss:0.03767556697130203
tensor([0.0872, 0.0377], device='cuda:0')
tensor([0.0872, 0.0377], device='cuda:1')
Epoch: 0071/0100 | train loss:0.08624932169914246
Epoch: 0071/0100 | train loss:0.03716084733605385
tensor([0.0862, 0.0372], device='cuda:0')
tensor([0.0862, 0.0372], device='cuda:1')
Epoch: 0072/0100 | train loss:0.08536970615386963
Epoch: 0072/0100 | train loss:0.03657805919647217
tensor([0.0854, 0.0366], device='cuda:0')
tensor([0.0854, 0.0366], device='cuda:1')
Epoch: 0073/0100 | train loss:0.08444425463676453
Epoch: 0073/0100 | train loss:0.036069512367248535
tensor([0.0844, 0.0361], device='cuda:0')
tensor([0.0844, 0.0361], device='cuda:1')
Epoch: 0074/0100 | train loss:0.08365066349506378
Epoch: 0074/0100 | train loss:0.035561252385377884
tensor([0.0837, 0.0356], device='cuda:0')
tensor([0.0837, 0.0356], device='cuda:1')
Epoch: 0075/0100 | train loss:0.0828193947672844
Epoch: 0075/0100 | train loss:0.03512110188603401
tensor([0.0828, 0.0351], device='cuda:0')
tensor([0.0828, 0.0351], device='cuda:1')
Epoch: 0076/0100 | train loss:0.08206731826066971
Epoch: 0076/0100 | train loss:0.03470907360315323
tensor([0.0821, 0.0347], device='cuda:0')
tensor([0.0821, 0.0347], device='cuda:1')
Epoch: 0077/0100 | train loss:0.08136867731809616
Epoch: 0077/0100 | train loss:0.03429228812456131
tensor([0.0814, 0.0343], device='cuda:0')
tensor([0.0814, 0.0343], device='cuda:1')
Epoch: 0078/0100 | train loss:0.08061014115810394
Epoch: 0078/0100 | train loss:0.03388326242566109
tensor([0.0806, 0.0339], device='cuda:0')
tensor([0.0806, 0.0339], device='cuda:1')
Epoch: 0079/0100 | train loss:0.07996807247400284
Epoch: 0079/0100 | train loss:0.0334811694920063
tensor([0.0800, 0.0335], device='cuda:0')
tensor([0.0800, 0.0335], device='cuda:1')
Epoch: 0080/0100 | train loss:0.07923366874456406
Epoch: 0080/0100 | train loss:0.03312436491250992
tensor([0.0792, 0.0331], device='cuda:0')
tensor([0.0792, 0.0331], device='cuda:1')
Epoch: 0081/0100 | train loss:0.07861354202032089
Epoch: 0081/0100 | train loss:0.03278031200170517
tensor([0.0786, 0.0328], device='cuda:0')
tensor([0.0786, 0.0328], device='cuda:1')
Epoch: 0082/0100 | train loss:0.07789915800094604
Epoch: 0082/0100 | train loss:0.03244069963693619
tensor([0.0779, 0.0324], device='cuda:0')
tensor([0.0779, 0.0324], device='cuda:1')
Epoch: 0083/0100 | train loss:0.07733096927404404
Epoch: 0083/0100 | train loss:0.03207029029726982
tensor([0.0773, 0.0321], device='cuda:0')
tensor([0.0773, 0.0321], device='cuda:1')
Epoch: 0084/0100 | train loss:0.07673352211713791
Epoch: 0084/0100 | train loss:0.031769514083862305
tensor([0.0767, 0.0318], device='cuda:0')
tensor([0.0767, 0.0318], device='cuda:1')
Epoch: 0085/0100 | train loss:0.07619936764240265
Epoch: 0085/0100 | train loss:0.031524963676929474
tensor([0.0762, 0.0315], device='cuda:0')
tensor([0.0762, 0.0315], device='cuda:1')
Epoch: 0086/0100 | train loss:0.07563362270593643
Epoch: 0086/0100 | train loss:0.03119492344558239
tensor([0.0756, 0.0312], device='cuda:0')
tensor([0.0756, 0.0312], device='cuda:1')
Epoch: 0087/0100 | train loss:0.0750502347946167
Epoch: 0087/0100 | train loss:0.03095475398004055
tensor([0.0751, 0.0310], device='cuda:0')
tensor([0.0751, 0.0310], device='cuda:1')
Epoch: 0088/0100 | train loss:0.0746132880449295
Epoch: 0088/0100 | train loss:0.030701685696840286
tensor([0.0746, 0.0307], device='cuda:0')
tensor([0.0746, 0.0307], device='cuda:1')
Epoch: 0089/0100 | train loss:0.07409549504518509
Epoch: 0089/0100 | train loss:0.030368996784090996
tensor([0.0741, 0.0304], device='cuda:0')
tensor([0.0741, 0.0304], device='cuda:1')
Epoch: 0090/0100 | train loss:0.0735851600766182
Epoch: 0090/0100 | train loss:0.03020581416785717
tensor([0.0736, 0.0302], device='cuda:0')
tensor([0.0736, 0.0302], device='cuda:1')
Epoch: 0091/0100 | train loss:0.07305028289556503
Epoch: 0091/0100 | train loss:0.029953384771943092
tensor([0.0731, 0.0300], device='cuda:0')
tensor([0.0731, 0.0300], device='cuda:1')
Epoch: 0092/0100 | train loss:0.07270056009292603
Epoch: 0092/0100 | train loss:0.029726726934313774
tensor([0.0727, 0.0297], device='cuda:0')
tensor([0.0727, 0.0297], device='cuda:1')
Epoch: 0093/0100 | train loss:0.07219361513853073
Epoch: 0093/0100 | train loss:0.02954575978219509
tensor([0.0722, 0.0295], device='cuda:0')
tensor([0.0722, 0.0295], device='cuda:1')
Epoch: 0094/0100 | train loss:0.07180915772914886
Epoch: 0094/0100 | train loss:0.02932337485253811
tensor([0.0718, 0.0293], device='cuda:0')
tensor([0.0718, 0.0293], device='cuda:1')
Epoch: 0095/0100 | train loss:0.07139516621828079
Epoch: 0095/0100 | train loss:0.029103577136993408
tensor([0.0714, 0.0291], device='cuda:0')
tensor([0.0714, 0.0291], device='cuda:1')
Epoch: 0096/0100 | train loss:0.07094169408082962
Epoch: 0096/0100 | train loss:0.02893088571727276
tensor([0.0709, 0.0289], device='cuda:0')
tensor([0.0709, 0.0289], device='cuda:1')
Epoch: 0097/0100 | train loss:0.028796857222914696
Epoch: 0097/0100 | train loss:0.07059731334447861
tensor([0.0706, 0.0288], device='cuda:0')
tensor([0.0706, 0.0288], device='cuda:1')
Epoch: 0098/0100 | train loss:0.028585290536284447
Epoch: 0098/0100 | train loss:0.0701548159122467
tensor([0.0702, 0.0286], device='cuda:0')
tensor([0.0702, 0.0286], device='cuda:1')
Epoch: 0099/0100 | train loss:0.06985291093587875
Epoch: 0099/0100 | train loss:0.028429213911294937
tensor([0.0699, 0.0284], device='cuda:0')
tensor([0.0699, 0.0284], device='cuda:1')
Epoch: 0100/0100 | train loss:0.06947710365056992
Epoch: 0100/0100 | train loss:0.028299672529101372
tensor([0.0695, 0.0283], device='cuda:0')
tensor([0.0695, 0.0283], device='cuda:1')
以上代码使用的 Fabric 对应 Lightning 2.1 版本;该工具仍在持续开发中,后续可能会加入更多功能。
2、真实数据的处理
2.1 下载数据
import os
from torchvision.datasets import MNIST
import torch
# Download into the "mnist_data" folder that sits next to this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
download_path = os.path.join(script_dir, "mnist_data")
os.makedirs(download_path, exist_ok=True)
print(f"正在下载MNIST数据集到 {download_path}...")

# Fetch both the training split and the test split.
train_dataset = MNIST(download_path, train=True, download=True)
test_dataset = MNIST(download_path, train=False, download=True)

print("MNIST数据集下载完成。")
print(f"文件位置: {download_path}")
print("文件列表:")
# List every file found under the download folder.
for folder, _subfolders, filenames in os.walk(download_path):
    for filename in filenames:
        print(os.path.join(folder, filename))
2.2 单机多卡模型训练
先按 2.1 将数据下载到脚本所在目录下的 mnist_data 文件夹,然后进行训练:
# 导入所需的库
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset # 导入 Dataset 以备后用
from torchvision import datasets, transforms # 用于加载 MNIST 数据集
from lightning.fabric import Fabric, seed_everything # 从 lightning.fabric 导入 Fabric 和 seed_everything
from torchmetrics.classification import Accuracy # 用于计算准确率
import time # 用于计时
# Small convolutional network used for the MNIST run.
class SimpleConvNet(nn.Module):
    """Two conv layers, max-pooling, dropout and two linear layers.

    ``forward`` returns log-probabilities over 10 classes, which is the form
    ``F.nll_loss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12, the flattened size for 28x28 inputs.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 1-channel images."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
# Train for a single epoch.
def train_epoch(fabric: Fabric, model: nn.Module, train_loader: DataLoader, optimizer: optim.Optimizer, epoch: int):
    """Run one pass over ``train_loader`` and log progress every 100 batches.

    Args:
        fabric: Fabric instance; supplies ``backward``, ``print`` and ``world_size``.
        model: Network producing log-probabilities (consumed by ``F.nll_loss``).
        train_loader: Loader yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the log line.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        fabric.backward(loss)  # backward pass goes through Fabric
        optimizer.step()

        if step % 100 != 0:
            continue

        # Progress estimate: batches so far * batch size * number of processes.
        samples_seen = step * len(inputs) * fabric.world_size
        dataset_size = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        fabric.print(f'训练 Epoch: {epoch} [{samples_seen}/{dataset_size} '
                     f'({percent_done:.0f}%)]\t损失: {loss.item():.6f}')
# Evaluate the model on the test set.
def test_epoch(fabric: Fabric, model: nn.Module, test_loader: DataLoader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Args:
        fabric: Fabric instance; supplies ``device``, ``all_gather`` and ``print``.
        model: Network producing log-probabilities (consumed by ``F.nll_loss``).
        test_loader: Loader yielding ``(data, target)`` batches.

    The summed loss of this process is gathered from all processes and divided
    by the full dataset size. Accuracy comes from ``Accuracy.compute()``.
    NOTE(review): this presumably aggregates across processes as well; confirm
    against the torchmetrics documentation.
    """
    model.eval()
    test_acc = Accuracy(task="multiclass", num_classes=10).to(fabric.device)
    total_loss_tensor = torch.tensor(0.0, device=fabric.device)

    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Sum (not mean) so the division by the dataset size below is exact.
            total_loss_tensor += F.nll_loss(output, target, reduction='sum')
            test_acc.update(output, target)

    gathered_losses = fabric.all_gather(total_loss_tensor)
    total_samples = len(test_loader.dataset)
    avg_loss = gathered_losses.sum() / total_samples

    final_acc = test_acc.compute()
    # Round instead of truncating: accuracy * N can land just below the true
    # integer (e.g. 9911.9997), and int() would then under-report by one.
    correct_samples = round(float(final_acc) * total_samples)

    fabric.print(f'\n测试集: 平均损失: {avg_loss:.4f}, '
                 f'准确率: {final_acc*100:.0f}% ({correct_samples}/{total_samples})\n')
    test_acc.reset()
# Main training routine.
def main(num_devices_to_use: int):
    """Train and evaluate ``SimpleConvNet`` on MNIST with Fabric DDP, then save it.

    Args:
        num_devices_to_use: Passed to ``Fabric(devices=...)``.
    """
    # --- Hyper-parameters ---
    batch_size_per_device = 64
    epochs = 5
    lr = 1.0
    seed = 42
    num_workers_per_loader = 2  # kept > 0 so that worker clean-up is exercised

    # --- Fabric initialisation (DDP with fp16 mixed precision) ---
    fabric = Fabric(
        accelerator="cuda",
        devices=num_devices_to_use,
        strategy="ddp",
        precision='16-mixed',
    )
    fabric.launch()

    # --- Reproducibility: every rank gets its own, fixed seed ---
    seed_everything(seed + fabric.global_rank)

    # --- Data preparation ---
    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(script_dir, "mnist_data")
    fabric.print(f"数据路径: {data_path}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Only global rank 0 downloads; the other ranks wait at the barrier.
    if fabric.is_global_zero:
        fabric.print(f"Rank {fabric.global_rank}: 检查并下载 MNIST 数据集...")
        for is_train_split in (True, False):
            datasets.MNIST(data_path, train=is_train_split, download=True)
        fabric.print(f"Rank {fabric.global_rank}: 数据集下载完成或已存在。")
    fabric.barrier()

    fabric.print(f"Rank {fabric.global_rank}: 继续加载数据集...")
    train_dataset = datasets.MNIST(data_path, train=True, download=False, transform=transform)
    test_dataset = datasets.MNIST(data_path, train=False, download=False, transform=transform)

    # --- Data loaders ---
    # persistent_workers is enabled whenever worker processes are used.
    shared_loader_options = dict(
        num_workers=num_workers_per_loader,
        persistent_workers=num_workers_per_loader > 0,
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size_per_device,
        shuffle=True,
        **shared_loader_options,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size_per_device * 2,
        shuffle=False,
        **shared_loader_options,
    )

    fabric.print(f"进程 {fabric.global_rank}/{fabric.world_size} 使用设备: {fabric.device}")
    fabric.print(f"训练数据集大小: {len(train_dataset)}")
    fabric.print(f"测试数据集大小: {len(test_dataset)}")
    fabric.print(f"每个设备的 Batch Size: {batch_size_per_device}")
    fabric.print(f"有效 Batch Size: {batch_size_per_device * fabric.world_size}")

    # --- Hand loaders, model and optimizer over to Fabric ---
    train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)

    model = SimpleConvNet()
    # If the model had BatchNorm layers and world_size > 1, convert before fabric.setup:
    # if fabric.world_size > 1:
    #     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)
    # setup moves the model to the device and applies the DDP wrapping.
    model, optimizer = fabric.setup(model, optimizer)

    # --- Training loop ---
    fabric.print(f"\n在 {fabric.world_size} 个 GPU 上开始训练...")
    start_time = time.time()
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train_epoch(fabric, model, train_loader, optimizer, epoch)
        # Optional barrier: make sure training finished everywhere before testing.
        fabric.barrier()
        test_epoch(fabric, model, test_loader)
        # Optional barrier after evaluation as well.
        fabric.barrier()
        epoch_time = time.time() - epoch_start_time
        fabric.print(f"Epoch {epoch} 在 {epoch_time:.2f} 秒内完成。")
    total_time = time.time() - start_time
    fabric.print(f"\n训练在 {total_time:.2f} 秒内完成。")

    # --- Saving, variant 1 (disabled): manual torch.save of the state_dict on rank 0 ---
    # if fabric.is_global_zero:
    #     save_path = "mnist_fabric_model.pt"
    #     fabric.print(f"Rank {fabric.global_rank}: 获取模型 state_dict...")
    #     # Moving the state_dict to CPU before saving is the recommended practice.
    #     state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
    #     fabric.print(f"Rank {fabric.global_rank}: 开始使用 torch.save 保存状态字典到 {save_path}...")
    #     try:
    #         torch.save(state_dict, save_path)  # plain torch.save
    #         print(f"\n模型状态字典已保存到 {save_path}")
    #         fabric.print(f"Rank {fabric.global_rank}: torch.save 保存完毕。")
    #     except Exception as e:
    #         fabric.print(f"Rank {fabric.global_rank}: 保存模型时发生错误: {e}")

    # --- Saving, variant 2: fabric.save, called on every rank (not under is_global_zero) ---
    save_path = "mnist_fabric_model_1.pt"
    fabric.save(save_path, {"model": model})
    fabric.print(f"\n模型状态已保存到 {save_path}")

    # No explicit final barrier: the trailing fabric.barrier() was removed on purpose.
    fabric.print(f"Rank {fabric.global_rank}: main 函数即将结束,没有显式的最终屏障。")
    # fabric.barrier()  # removed
# Script entry point: check the GPU requirements, then start training.
if __name__ == "__main__":
    num_devices = 2
    cuda_ready = torch.cuda.is_available()
    # Only query the device count when CUDA is actually usable.
    visible_gpus = torch.cuda.device_count() if cuda_ready else 0

    if not cuda_ready:
        print("错误:未检测到 CUDA。此脚本需要 GPU 运行。")
    elif visible_gpus < num_devices:
        print(f"错误:需要 {num_devices} 个 GPU,但只找到 {visible_gpus} 个。")
    else:
        print(f"检测到 {visible_gpus} 个 GPU。将使用指定的 {num_devices} 个 GPU。")
        main(num_devices_to_use=num_devices)
以上代码提供了两种模型保存方式。切记不要把 fabric.save 放到 fabric.is_global_zero 判断之下:fabric.save 需要所有进程共同调用(内部会进行进程间同步),如果只在 rank 0 上调用,其余进程不会参与同步,程序就会卡住不动。好了,使用 Fabric 进行单机多卡训练还是很方便的。