环境配置
# Install OpenMMLab's package manager (mim), then use it to install the
# mmengine and mmcv (>= 2.0.0) dependencies.
! pip install -U openmim
! mim install mmengine
! mim install "mmcv>=2.0.0"
# Fetch the mmsegmentation sources (main branch) and install them in editable
# mode from inside the checkout; the working directory stays in the checkout.
! git clone -b main https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
! pip install -v -e .
验证安装
# Download the pretrained PSPNet (Cityscapes, 512x1024) files into the current directory.
! mim download mmsegmentation --config pspnet_r50-d8_4xb2-40k_cityscapes-512x1024 --dest .
# Smoke test: run the bundled demo image through that checkpoint on GPU 0 and
# write the rendered segmentation to result.jpg.
! python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth --device cuda:0 --out-file result.jpg
准备数据集
! wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip
import zipfile
zip_path = '/content/mmsegmentation/Watermelon87_Semantic_Seg_Mask.zip'
extract_path = './data/'
zip_file = zipfile.ZipFile( zip_path, 'r' )
zip_file.extractall( extract_path)
zip_file.close( )
修改cfg文件
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class WatermelonDataset(BaseSegDataset):
    """Watermelon87 semantic segmentation dataset.

    Images are matched by the ``.jpg`` suffix and segmentation maps by the
    ``.png`` suffix; both can be overridden through the constructor.

    NOTE(review): ``background`` is listed first on the assumption that mask
    value 0 is background and values 1-5 are the five watermelon parts. This
    makes the class list agree with ``num_classes = 6`` used by the model
    heads in this file; previously only five names were declared, so class
    names and metrics were shifted by one. Confirm against the annotation
    masks.

    Args:
        img_suffix: File suffix of the input images.
        seg_map_suffix: File suffix of the segmentation maps.
        **kwargs: Forwarded unchanged to ``BaseSegDataset``.
    """

    METAINFO = dict(
        classes=('background', 'red_flesh', 'shell', 'white_rind',
                 'black_seed', 'white_seeds'),
        # One RGB colour per class, in the same order as ``classes``.
        palette=[[0, 0, 0], [128, 64, 128], [244, 35, 232], [70, 70, 70],
                 [102, 102, 156], [190, 153, 153]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
# Normalisation-layer settings shared by the backbone and both heads.
norm_cfg = dict(type='BN', requires_grad=True)

# Per-channel mean/std, channel-order flag, pad values and pad size handed to
# the SegDataPreProcessor.
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=(256, 256),
)

# Encoder-decoder segmentor: ResNetV1c-50 backbone, PSP decode head on the
# last stage (in_index=3) and an auxiliary FCN head on stage in_index=2.
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True,
    ),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=6,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    ),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=6,
        norm_cfg=norm_cfg,
        align_corners=False,
        # The auxiliary loss is weighted lower (0.4) than the main loss (1.0).
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4),
    ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
dataset_type = 'WatermelonDataset'
data_root = 'data/Watermelon87_Semantic_Seg_Mask/'
# Training crop size. The random crop below uses this value so that crops
# match the (256, 256) pad size configured on the data preprocessor in this
# file; the previous hard-coded (64, 64) crop left most of every padded
# training input as padding.
crop_size = (256, 256)

# Training-time loading and augmentation.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='PackSegInputs'),
]

# Validation/test loading: resize only, no augmentation.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
    dict(type='LoadAnnotations'),
    dict(type='PackSegInputs'),
]

# Test-time augmentation: every scale in img_ratios, with and without a
# horizontal flip.
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0.0, direction='horizontal'),
                dict(type='RandomFlip', prob=1.0, direction='horizontal'),
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')],
        ]),
]

# The dataloaders reference the pipelines defined above instead of carrying
# their own copies, so the two can no longer drift apart.
train_dataloader = dict(
    batch_size=8,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='img_dir/train', seg_map_path='ann_dir/train'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))
# Testing reads the same img_dir/val and ann_dir/val folders as validation.
test_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(img_path='img_dir/val', seg_map_path='ann_dir/val'),
        pipeline=test_pipeline))

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
# Registry scope used to resolve the ``type`` names in this config.
default_scope = 'mmseg'

# Process-level environment: cuDNN benchmark flag, multiprocessing start
# method / OpenCV thread count, and the distributed backend.
env_cfg = dict(
    cudnn_benchmark=True,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'),
)

# Visualisation goes through a single local backend.
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer',
)

# Logging is iteration-based rather than epoch-based.
log_processor = dict(by_epoch=False)
log_level = 'INFO'

# No checkpoint is preloaded and training does not resume a previous run.
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
# SGD optimizer; the wrapper reuses the same settings and applies no gradient clipping.
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=optimizer,
    clip_grad=None)

# Polynomial learning-rate decay, stepped per iteration.
# ``end`` is kept equal to ``train_cfg.max_iters`` below. The previous value
# of 40000 was far beyond the 3000 iterations actually run, so the learning
# rate barely moved from its initial value during training.
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0.0001,
        power=0.9,
        begin=0,
        end=3000,
        by_epoch=False)
]

# Iteration-based training: 3000 iterations, validating every 400.
train_cfg = dict(type='IterBasedTrainLoop', max_iters=3000, val_interval=400)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # Checkpoints are written every 1500 iterations.
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=1500),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

# Fixed seed for reproducibility.
randomness = dict(seed=0)
# Directory name kept unchanged: the inference command later in this file
# loads its checkpoint from this path.
work_dir = './work_dirs/DubaiDataset'
模型训练
# Launch training through mmsegmentation's train script with the config at
# /content/mmsegmentation/cfg.py.
! python tools/train.py /content/mmsegmentation/cfg.py
模型测试
import requests
from PIL import Image
from io import BytesIO
url = "https://p9.itc.cn/q_70/images03/20230614/d7c96d1cf185436098db0bde886de974.jpeg"
response = requests.get( url)
img = Image.open( BytesIO( response.content))
img.save( "image.jpg" )
! python demo/image_demo.py image.jpg /content/mmsegmentation/cfg.py /content/mmsegmentation/work_dirs/DubaiDataset/iter_17.pth --device cuda:0 --out-file result.jpg
import matplotlib.pyplot as plt
from PIL import Image

# Display the rendered segmentation result with the axes hidden.
img = Image.open("/content/mmsegmentation/result.jpg")
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_axis_off()
plt.show()