配置好mmsegmentation 和mmcv (见上一篇博客)
这里以swin transform 为例,这里新建my_data目录复制进去所需要的文件
1. 新建my_data->swin
将所有的实验训练的config文件全部放在swin文件下
根据想要训练的网络,在configs目录选取网络模型及输入尺寸,迭代次数等
(1)这里训练的是swin,所以选取Swin_Transformer/configs/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py
将其放在新建my_data->swin文件夹下。
upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py文件中是这样的。所以再将
'../_base_/models/upernet_swin.py', '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 复制过来
(2)选取Swin_Transformer/configs/_base_/models/upernet_swin.py文件复制到 新建my_data->swin文件夹下
(3)因为我这里使用的是voc格式数据集,所以选取Swin_Transformer/configs/_base_/datasets/pascal_voc12.py文件复制到
新建my_data->swin文件夹下
(4)将Swin_Transformer/configs/_base_/default_runtime.py文件复制到
新建my_data->swin文件夹下
(5)将Swin_Transformer/configs/_base_/schedules/schedule_160k.py文件复制到新建my_data->swin文件夹下
(6)新建一个空白__init__.py文件
my_data->swin文件夹下,如图
2.按照需求修改训练config (即新建my_data->swin文件夹下的文件)
(1)修改pascal_voc12.py,因为使用的是自定义数据集
一、创建一个my_datasets.py用这个代替pascal_voc12.py,并修改dataset_type = 'MyDataset'
二、此时新建data文件夹,将自己的数据集(VOC格式),放在其中,格式如下,并且图片为jpg,mask为png。
三、在mmseg/datasets新建一个文件my_dataset.py仿照voc.py修改内容
四、在mmseg/datasets/__init__.py
中把自己的数据集添加进去:
(2)修改upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py,(因为位置发生改变)
这里还需要修改num_classes 为自己的类别+1 算上背景 !!!!!
Swin_Transformer/my_data/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py
(3)修改schedule_160k.py
【mmsegmentation】工程--使用小技巧_mmseg 使用epoch训练-CSDN博客
修改最大迭代次数,以及模型保存的间隔,评估的间隔
mmseg默认训练的次数是按照inter去计算的,swin中160000个inter,每16000次inter进行一次模型验证,并保存一次模型。
这里修改为epoch来验证和保存模型,修改config
另外 保存bestmodel 在mmseg/apis/train.py,找到runner.register_hook这里,修改原有的钩子定义
(4)修改upernet_swin.py文件 my_data/swin/upernet_swin.py
修改norm_cfg SyncBN->BN 单卡用BN
修改num_classes 2处
(5)修改
3. 训练修改
(1)修改 /hy-tmp/Swin_Transformer/tools/train.py
parser.add_argument('config', help='Swin_Transformer/my_data/swin/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k.py')
parser.add_argument('--work-dir', help='Swin_Transformer/my_data/swin_logs') #日志保存
parser.add_argument( '--load-from', help='Swin_Transformer/weights/swin_tiny_patch4_window7_224.pth')
(2)单卡训练,修改
4. 进入到tools路径下 python train.py 模型训练