Trying out mmsegmentation

1. Installation

The environment is best set up on Ubuntu (Linux); on Windows you will run into a problem that is fairly hard to solve.

https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/get_started.md#installation

 

After opening the link, use the first method (installing from source), since we will use the repository directly later. You can actually also pip install mmsegmentation at the same time. Below is the virtual environment I created, for your reference:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    defaults
_openmp_mutex             5.1                       1_gnu    defaults
addict                    2.4.0                    pypi_0    pypi
anykeystore               0.2                      pypi_0    pypi
apex                      0.1                      pypi_0    pypi
blas                      1.0                         mkl    defaults
ca-certificates           2022.4.26            h06a4308_0    defaults
certifi                   2022.6.15        py38h06a4308_0    defaults
charset-normalizer        2.0.12                   pypi_0    pypi
click                     8.1.3                    pypi_0    pypi
colorama                  0.4.5                    pypi_0    pypi
commonmark                0.9.1                    pypi_0    pypi
cryptacular               1.6.2                    pypi_0    pypi
cycler                    0.11.0                   pypi_0    pypi
defusedxml                0.7.1                    pypi_0    pypi
fonttools                 4.33.3                   pypi_0    pypi
greenlet                  1.1.2                    pypi_0    pypi
hupper                    1.10.3                   pypi_0    pypi
idna                      3.3                      pypi_0    pypi
importlib-metadata        4.12.0                   pypi_0    pypi
intel-openmp              2021.4.0          h06a4308_3561    defaults
kiwisolver                1.4.3                    pypi_0    pypi
ld_impl_linux-64          2.38                 h1181459_1    defaults
libffi                    3.3                  he6710b0_2    defaults
libgcc-ng                 11.2.0               h1234567_1    defaults
libgfortran-ng            7.5.0               ha8ba4b0_17    defaults
libgfortran4              7.5.0               ha8ba4b0_17    defaults
libgomp                   11.2.0               h1234567_1    defaults
libstdcxx-ng              11.2.0               h1234567_1    defaults
markdown                  3.3.7                    pypi_0    pypi
markupsafe                2.1.1                    pypi_0    pypi
matplotlib                3.5.2                    pypi_0    pypi
mkl                       2021.4.0           h06a4308_640    defaults
mkl-service               2.4.0            py38h7f8727e_0    defaults
mkl_fft                   1.3.1            py38hd3c417c_0    defaults
mkl_random                1.2.2            py38h51133e4_0    defaults
mmcls                     0.23.1                   pypi_0    pypi
mmcv-full                 1.5.3                    pypi_0    pypi
mmsegmentation            0.25.0                   pypi_0    pypi
model-index               0.1.11                   pypi_0    pypi
ncurses                   6.3                  h7f8727e_2    defaults
numpy                     1.23.0                   pypi_0    pypi
numpy-base                1.22.3           py38hf524024_0    defaults
oauthlib                  3.2.0                    pypi_0    pypi
opencv-python             4.6.0.66                 pypi_0    pypi
openmim                   0.2.0                    pypi_0    pypi
openssl                   1.1.1o               h7f8727e_0    defaults
ordered-set               4.1.0                    pypi_0    pypi
packaging                 21.3                     pypi_0    pypi
pandas                    1.4.3                    pypi_0    pypi
pastedeploy               2.1.1                    pypi_0    pypi
pbkdf2                    1.3                      pypi_0    pypi
pillow                    9.1.1                    pypi_0    pypi
pip                       21.2.4           py38h06a4308_0    defaults
plaster                   1.0                      pypi_0    pypi
plaster-pastedeploy       0.7                      pypi_0    pypi
prettytable               3.3.0                    pypi_0    pypi
protobuf                  3.20.1                   pypi_0    pypi
pygments                  2.12.0                   pypi_0    pypi
pyparsing                 3.0.9                    pypi_0    pypi
pyramid                   2.0                      pypi_0    pypi
pyramid-mailer            0.15.1                   pypi_0    pypi
python                    3.8.13               h12debd9_0    defaults
python-dateutil           2.8.2                    pypi_0    pypi
python3-openid            3.2.0                    pypi_0    pypi
pytz                      2022.1                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
readline                  8.1.2                h7f8727e_1    defaults
repoze-sendmail           4.4.1                    pypi_0    pypi
requests                  2.28.0                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
rich                      12.4.4                   pypi_0    pypi
scipy                     1.7.3            py38hc147768_0    defaults
setuptools                61.2.0           py38h06a4308_0    defaults
six                       1.16.0             pyhd3eb1b0_1    defaults
sqlalchemy                1.4.39                   pypi_0    pypi
sqlite                    3.38.5               hc218d9a_0    defaults
tabulate                  0.8.10                   pypi_0    pypi
tensorboardx              2.5.1                    pypi_0    pypi
timm                      0.3.2                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0    defaults
torch                     1.8.0+cu111              pypi_0    pypi
torchvision               0.9.0+cu111              pypi_0    pypi
transaction               3.0.1                    pypi_0    pypi
translationstring         1.4                      pypi_0    pypi
typing-extensions         4.2.0                    pypi_0    pypi
urllib3                   1.26.9                   pypi_0    pypi
velruse                   1.1.1                    pypi_0    pypi
venusian                  3.0.0                    pypi_0    pypi
wcwidth                   0.2.5                    pypi_0    pypi
webob                     1.8.7                    pypi_0    pypi
wheel                     0.37.1             pyhd3eb1b0_0    defaults
wtforms                   3.0.1                    pypi_0    pypi
wtforms-recaptcha         0.3.2                    pypi_0    pypi
xz                        5.2.5                h7f8727e_1    defaults
yapf                      0.32.0                   pypi_0    pypi
zipp                      3.8.0                    pypi_0    pypi
zlib                      1.2.12               h7f8727e_2    defaults
zope-deprecation          4.4.0                    pypi_0    pypi
zope-interface            5.4.0                    pypi_0    pypi
zope-sqlalchemy           1.6                      pypi_0    pypi
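
The key packages in that list can be installed with commands along these lines — a sketch based on the official install guide, with the CUDA 11.1 / torch 1.8.0 wheels chosen to match the environment above:

```shell
# create and activate a fresh conda environment
conda create -n mmseg python=3.8 -y
conda activate mmseg

# PyTorch wheels matching CUDA 11.1, as in the environment list above
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 \
    -f https://download.pytorch.org/whl/torch_stable.html

# mmcv-full prebuilt for this torch/CUDA combination
pip install mmcv-full==1.5.3 \
    -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html

# install mmsegmentation from source so configs/ and tools/ are available
git clone https://github.com/open-mmlab/mmsegmentation.git
cd mmsegmentation
pip install -e .
```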

2. Training

Below is the official example. Just following it step by step is enough to grasp the logic of the whole toolkit; I will add a few notes here.

https://github.com/open-mmlab/mmsegmentation/blob/master/demo/MMSegmentation_Tutorial.ipynb

Data download: the dataset is also linked in the tutorial above.

After downloading, don't forget to convert the annotations to the standard format; the step below does that conversion.
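
The conversion code isn't reproduced here, so below is a sketch following the approach in the official tutorial: each .regions.txt label matrix in the Stanford Background dataset is turned into a palette-mode PNG mask. The helper names convert_region_file and convert_all are my own, not part of mmseg:

```python
import os
import os.path as osp

import numpy as np
from PIL import Image

# Palette from the training script below: one RGB triple per class index 0-7.
PALETTE = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34],
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]


def convert_region_file(txt_path, png_path):
    """Convert one .regions.txt label matrix into a palette-mode PNG mask."""
    seg_map = np.loadtxt(txt_path).astype(np.uint8)
    seg_img = Image.fromarray(seg_map).convert('P')
    seg_img.putpalette(np.array(PALETTE, dtype=np.uint8))
    seg_img.save(png_path)


def convert_all(ann_dir):
    """Convert every .regions.txt under ann_dir, writing a .png next to it."""
    for name in os.listdir(ann_dir):
        if name.endswith('.regions.txt'):
            convert_region_file(
                osp.join(ann_dir, name),
                osp.join(ann_dir, name.replace('.regions.txt', '.png')))
```

For this dataset you would call convert_all on the labels directory, e.g. convert_all('./stanford_background/iccv09Data/labels').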

 

Create a main.py file as the training script, also pulled from the link above. The code logic is quite clear if you read it closely: first data preparation, then config modification, then building the model, and finally training.

import mmcv
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmcv import Config
from mmseg.apis import set_random_seed
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


data_root = './stanford_background/iccv09Data'
img_dir = 'images'
ann_dir = 'labels'
classes = ('sky', 'tree', 'road', 'grass', 'water', 'bldg', 'mntn', 'fg obj')
palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34], 
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]

@DATASETS.register_module()
class StanfordBackgroundDataset(CustomDataset):
  CLASSES = classes
  PALETTE = palette
  def __init__(self, split, **kwargs):
    super().__init__(img_suffix='.jpg', seg_map_suffix='.png', 
                     split=split, **kwargs)
    assert osp.exists(self.img_dir) and self.split is not None

# split train/val set randomly
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
  # select first 4/5 as train set
  train_length = int(len(filename_list)*4/5)
  f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
  # select last 1/5 as val set
  f.writelines(line + '\n' for line in filename_list[train_length:])

############################################################################################
# Train PSPNet here by modifying its config inside the script. The config has
# many more parameters; open it and look at the details if you want to change others.
cfg = Config.fromfile('./configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')

# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8

# Modify dataset type and path
cfg.dataset_type = 'StanfordBackgroundDataset'
cfg.data_root = data_root

cfg.data.samples_per_gpu = 8
cfg.data.workers_per_gpu = 1

cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

# Initialize from a pre-trained PSPNet checkpoint (not Mask R-CNN; that comment
# in the tutorial was copied over from the detection tutorial).
# The download link for the checkpoint is in the tutorial above.
cfg.load_from = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

# Set up working dir to save checkpoints and logs generated during training.
cfg.work_dir = './run/pspnet/'

cfg.runner.max_iters = 100000
cfg.log_config.interval = 20
cfg.evaluation.interval = 500
cfg.checkpoint_config.interval = 200

# Set seed to facilitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')

# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the segmentor
model = build_segmentor(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

Once the above is done, just run main.py to start training — very convenient.
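
Training and a quick evaluation can be kicked off from the shell. A sketch, assuming the source install above and that the modified config has been dumped to the work dir (e.g. by adding cfg.dump('./run/pspnet/config.py') to main.py before training starts):

```shell
# start training with the script above
python main.py

# after training, report mIoU on the val split with the stock test script,
# using the dumped config and the latest checkpoint from the work dir
python tools/test.py ./run/pspnet/config.py ./run/pspnet/latest.pth --eval mIoU
```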

3. Inference

Below is the inference code I used. The class names and colour palette are required, along with the same config used for training.

import mmcv
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
from mmcv import Config
from mmseg.apis import set_random_seed
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor, inference_segmentor, init_segmentor, show_result_pyplot
import os
import cv2
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


data_root = './stanford_background/iccv09Data'
img_dir = 'images'
ann_dir = 'labels'
classes = ('sky', 'tree', 'road', 'grass', 'water', 'bldg', 'mntn', 'fg obj')
palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34], 
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]

@DATASETS.register_module()
class StanfordBackgroundDataset(CustomDataset):
  CLASSES = classes
  PALETTE = palette
  def __init__(self, split, **kwargs):
    super().__init__(img_suffix='.jpg', seg_map_suffix='.png', 
                     split=split, **kwargs)
    assert osp.exists(self.img_dir) and self.split is not None


############################################################################################
cfg = Config.fromfile('./configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')

# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8

# Modify dataset type and path
cfg.dataset_type = 'StanfordBackgroundDataset'
cfg.data_root = data_root

cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 0

cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

# Initialize from the same pre-trained PSPNet checkpoint as in the training
# script (again, the original Mask R-CNN comment was copied from mmdetection)
cfg.load_from = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './run'

cfg.runner.max_iters = 200
cfg.log_config.interval = 10
cfg.evaluation.interval = 200
cfg.checkpoint_config.interval = 200

# Set seed to facilitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# Let's have a look at the final config used for inference
print(f'Config:\n{cfg.pretty_text}')

config_file = cfg
checkpoints_file = './run/pspnet/latest.pth'
model = init_segmentor(config_file, checkpoints_file, device='cuda:0')

img = mmcv.imread('./stanford_background/iccv09Data/images/6000124.jpg')

result = inference_segmentor(model, img)
print(result)
#plt.figure(figsize=(8, 6))
# show_result_pyplot(model, img, result, palette)
#model.show_result(img, result, show=True)

# result[0] is the raw map of class indices (0-7); cast to uint8 so cv2.imwrite
# accepts it (viewed directly, the saved image will look almost black)
cv2.imwrite('./re.jpg', result[0].astype(np.uint8))
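
Since result[0] is just a map of class indices, a quick way to get a viewable image is to index the palette with it. A minimal sketch — the colorize helper is my own, not part of mmseg:

```python
import numpy as np

# Palette from the script above: one RGB triple per class index 0-7.
PALETTE = np.array([[128, 128, 128], [129, 127, 38], [120, 69, 125],
                    [53, 125, 34], [0, 11, 123], [118, 20, 12],
                    [122, 81, 25], [241, 134, 51]], dtype=np.uint8)


def colorize(seg_map, palette=PALETTE):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return palette[seg_map]
```

In the script above this would be cv2.imwrite('./re_color.jpg', colorize(result[0])[..., ::-1]) — the [..., ::-1] flips RGB to the BGR channel order OpenCV expects.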

 
