显存不够用:
1. 降低batchsize
2. 降低图像输入尺寸
3. 半精度训练,将数据由float32改为float16
模型保存
1. 单卡保存单卡加载
2. 单卡保存多卡加载
3. 多卡保存单卡加载
4. 多卡保存多卡加载
自定义损失函数
以类的方式定义
class DiceLoss(nn.Module):
def __init__(self,weight=None,size_average=True):
super(DiceLoss,self).__init__()
def forward(self,inputs,targets,smooth=1):
inputs = F.sigmoid(inputs)
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
return 1 - dice
# 使用方法
criterion = DiceLoss()
loss = criterion(input,targets)
模型微调
使用timm库,里面保存了很多的STOA模型,pip install timm即可