模板代码概述
1. 数据集函数
class MyDataset(Dataset):
def __init__(self, img_id_list, IMG_SIZE, mode='train', augmentation=False):
"""传参,定义参数
1. 数据集列表,
- 本地数据,文件名/图片名
- API,图片ID
2. 图片读取尺寸
3. 训练模式or推理模式
4. 是否做Data augmentation
...
"""
pass
def __getitem__(self, idx):
"""读取下一个样本
1. 读取本地图片,或读API接口获取base64格式图片
2. 预处理, 如变换图片尺寸
3. 若训练集,读取Mask图片
4. Data augmentation
"""
pass
def __len__(self):
"""定义样本个数
"""
pass
def prepare_trainset():
"""
1. 切分数据集,训练集/验证集
2. 定义MyDataset训练集、MyDataset验证集
3. 定义Pytorch的DataLoader
train_dl = DataLoader(
train_dataset,
batch_size=16,
shuffle=True,
#sampler=sampler,
num_workers=8,
drop_last=True
)
val_dl = DataLoader(
val_dataset,
batch_size=16,
shuffle=False,
#sampler=sampler,
num_workers=8,
drop_last=True
)
"""
pass
2. Utils函数
3. 分割的评估函数
4. 训练脚本
def run_training():
"""training pipline
1. 读取network
- 加载预训练模型
- 定义训练全部层的参数/哪几层参数
- 定义学习率/为每一层定义学习率
- 定义优化函数optimizer、学习率变化方案scheduler
-
2. 训练N_EPOCH次迭代,每一个迭代内:
- 用DataLoader循环读取训练集上每一个batch数据(N个图片、N个mask)
- 将N个图片传入network,输出模型最后一层的预测(sigmoid概率)
- 计算这个batch上的loss、metric,并存下来
- 反向传播,更新参数(.backward())(是否梯度累加)
- 计算所有batch上loss、metric的总体均值,代表这个EPOCH
- 用DataLoader循环读取验证集上每一个batch数据,与以上操作相似,计算验证集上的loss、metric,用于决定哪一个EPOCH停止训练
- 更新logging、保存checkpoint
"""
pass
5. Unet介绍