Github复现之deeplab v3+(用自己的遥感数据集训练)

原始链接:https://github.com/jfzhang95/pytorch-deeplab-xception
1.数据准备
训练
验证
images文件夹和labels文件夹内的图像和标签名是一一对应的,名字是一样的,标签的具体内容应该是0,1,2,3这样代表类别的数据。文件夹名字最好和我的都一样,因为代码里有的地方写了文件名。

2.数据相关代码
(1)数据读入,创建对应的脚本放进对应的位置
dataloaders/datasets/own_data.py

import os
import cv2
import random
import numpy as np
from shutil import copyfile, move
from osgeo import gdal
from PIL import Image
import torch
import torch.utils.data as data
from torchvision import transforms
from dataloaders import custom_transforms as tr

def read_img(filename):
    """Open a raster file with GDAL and return its georeferencing and pixels.

    Returns a tuple (projection, geotransform, width, height, data) where
    data is the full-extent pixel array produced by ReadAsArray.
    """
    ds = gdal.Open(filename)

    width = ds.RasterXSize
    height = ds.RasterYSize
    geotrans = ds.GetGeoTransform()
    proj = ds.GetProjection()
    data = ds.ReadAsArray(0, 0, width, height)

    # Dropping the reference closes the GDAL dataset handle.
    del ds
    return proj, geotrans, width, height, data

def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array as a GeoTIFF with the given projection/geotransform.

    im_data may be 2-D (single band) or 3-D (bands, height, width); the
    GDAL data type is inferred from the array dtype name.
    """
    # Map numpy dtype -> GDAL type ('int8' also matches 'uint8' by substring).
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # BUG FIX: release the dataset so GDAL flushes the pixel data to disk;
    # without this the written GeoTIFF can end up empty or truncated.
    del dataset

def gdal_loader(img_path, mask_path):
    """Load an image/mask pair via GDAL.

    Keeps only the first three bands of the image and scales it to [0, 1];
    the mask array is returned as read.
    """
    _, _, _, _, mask = read_img(mask_path)
    _, _, _, _, img = read_img(img_path)
    # The source imagery has 4 bands; only three are used for training
    # (supporting all 4 bands would require changes elsewhere).
    img = img[:3] / 255.0
    return img, mask

def read_own_data(root_path, split='train'):
    """Collect parallel (image, label) path lists for one dataset split.

    Expects <root_path>/<split>/images and <root_path>/<split>/labels to
    contain identically named files, so images[i] pairs with masks[i].

    CHANGED: the directory listing is sorted — os.listdir order is
    filesystem-dependent, and a deterministic order keeps runs reproducible.
    """
    image_root = os.path.join(root_path, split + '/images')
    gt_root = os.path.join(root_path, split + '/labels')

    images = []
    masks = []
    for image_name in sorted(os.listdir(image_root)):
        images.append(os.path.join(image_root, image_name))
        masks.append(os.path.join(gt_root, image_name))

    return images, masks

def own_data_loader(img_path, mask_path):
    """Load an image/mask pair with OpenCV as float32 (C, H, W) arrays.

    Both image and mask are scaled by 1/255. NOTE(review): dividing the
    mask by 255 only makes sense for binary {0, 255} label images; for
    multi-class {0,1,2,...} labels that division should be removed.
    """
    img = cv2.imread(img_path)
    img = img / 255.0
    # BUG FIX: the mask was never read (the reading line was commented
    # out), so `mask` below raised NameError. Read it as single-channel.
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    mask = np.expand_dims(mask, axis=2)
    mask = mask / 255.0
    img = np.array(img, np.float32).transpose(2, 0, 1)
    mask = np.array(mask, np.float32).transpose(2, 0, 1)
    return img, mask


class ImageFolder(data.Dataset):
    """Segmentation dataset over the user's own remote-sensing imagery.

    Pairs files of the same name under <root_path>/<split>/images and
    <root_path>/<split>/labels, loads them with GDAL (first 3 of 4 bands,
    scaled to [0, 1]) and returns raw arrays — the transform methods below
    are currently bypassed in __getitem__.
    """

    # Number of label classes in the annotation (pixel values 0..3).
    NUM_CLASSES = 4

    def __init__(self, args, split='train'):
        # args must provide .root_path — the dataset root directory.
        self.args = args
        self.root = self.args.root_path
        self.split = split
        self.images, self.labels = read_own_data(self.root, self.split)
    
    def transform_tr(self, sample):
        """Training transform: normalize to [-1, 1] and convert to tensor.

        Currently unused — __getitem__ returns untransformed arrays; the
        augmentation steps are commented out because the GDAL-loaded
        4-band data does not go through the stock augmentation pipeline.
        """
        composed_transforms = transforms.Compose([
            # tr.RandomHorizontalFlip(),
            # tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            # tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            tr.ToTensor()
            ])

        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation transform: normalize and convert to tensor (currently unused)."""

        composed_transforms = transforms.Compose([
            # tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            tr.ToTensor()
            ])

        return composed_transforms(sample)

    def __getitem__(self, index):
        """Return {'image': float array (3, H, W), 'label': mask array}."""
        # img, mask = own_data_loader(self.images[index], self.labels[index])
        img, mask = gdal_loader(self.images[index], self.labels[index])
        # img = torch.Tensor(img)
        # mask = torch.Tensor(mask)

        sample = {'image': img, 'label': mask}
        # sample = {'image': _img, 'label': _target}

        # if self.split == "train":
        #     return self.transform_tr(sample)
        # elif self.split == 'val':
        #     return self.transform_val(sample)
        return sample

    def __len__(self):
        # Images and labels are parallel lists built from same-named files.
        assert len(self.images) == len(self.labels), 'The number of images must be equal to labels'
        return len(self.images)

(2)修改dataloaders/init.py文件,添加自己数据的调用入口

from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd, own_data
from torch.utils.data import DataLoader

def make_data_loader(args, **kwargs):
    """Build (train, val, test) DataLoaders plus class count for args.dataset.

    kwargs (e.g. num_workers, pin_memory) are forwarded to every DataLoader.
    Returns (train_loader, val_loader, test_loader, num_class); test_loader
    is None for datasets without a test split. Raises NotImplementedError
    for unknown dataset names.
    """
    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation(args, split='train')
        val_set = pascal.VOCSegmentation(args, split='val')
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set])

        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'own':
        train_set = own_data.ImageFolder(args, split='train')
        val_set = own_data.ImageFolder(args, split='val')
        num_class = train_set.NUM_CLASSES
        # CHANGED for consistency with the other branches: forward kwargs
        # (workers/pin_memory), do not shuffle validation data, and do not
        # drop validation samples. drop_last is kept on the train loader so
        # a trailing short batch never reaches BatchNorm.
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError

3.训练代码修改,要改的地方不多,因为数据的导入已经改到适应训练了,下面有两个地方要改,给了中文注释的,认真看下
train.py,训练的时候会自动创建run文件夹放模型

import argparse
import os
import numpy as np
from tqdm import tqdm

from mypath import Path
from dataloaders import make_data_loader
from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator

class Trainer(object):
    """Owns the full training state for DeeplabV3+: data loaders, model,
    optimizer, loss, LR scheduler, evaluator, tensorboard writer and
    checkpoint saver. training()/validation() run one epoch each."""

    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        # Two parameter groups: one trained at the base LR, one at 10x.
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            # Weights are cached next to the dataset; computed once if missing.
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # Needed so SynchronizedBatchNorm works under DataParallel.
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # Under DataParallel the real model lives in .module.
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        """Run one training epoch: per-iteration LR scheduling, forward,
        backward, optimizer step, and tensorboard/console loss logging.
        Saves a checkpoint every epoch when validation is disabled."""
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                # .float() casts the image to float32 (the GDAL loader
                # produces float64 arrays, which the model does not accept).
                image, target = image.cuda().float(), target.cuda()
            # Scheduler is called per iteration and adjusts the optimizer LR.
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Show 10 * 3 inference results each epoch
            # if i % (num_img_tr // 10) == 0:
            #     global_step = i + num_img_tr * epoch
            #     self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)   # commented out: the stock visualization targets the 21-class VOC data and does not work here

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)


    def validation(self, epoch):
        """Evaluate on the validation set: accumulate confusion statistics,
        log Acc/mIoU/FWIoU, and checkpoint when mIoU improves."""
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda().float(), target.cuda()
            with torch.no_grad():
                output = self.model(image)
            loss = self.criterion(output, target)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            # Per-pixel class = argmax over the channel dimension.
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        # mIoU is the model-selection metric: save when it improves.
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

def main():
    """Parse command-line arguments, fill in dataset-dependent defaults,
    and run DeeplabV3+ training with periodic validation."""
    parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone', type=str, default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out-stride', type=int, default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--dataset', type=str, default='pascal',
                        choices=['pascal', 'coco', 'cityscapes', 'own'],
                        help='dataset name (default: pascal)')

    # Root directory of the user's own dataset (contains train/ and val/).
    # BUG FIX: this line previously ended with markdown "**...**" markers
    # that had leaked into the source and were a SyntaxError.
    parser.add_argument('--root_path', type=str, default='./data/',
                        help='the own path root')

    parser.add_argument('--use-sbd', action='store_true', default=True,
                        help='whether to use SBD dataset (default: True)')
    parser.add_argument('--workers', type=int, default=4,
                        metavar='N', help='dataloader threads')
    parser.add_argument('--base-size', type=int, default=512,
                        help='base image size')
    parser.add_argument('--crop-size', type=int, default=448,
                        help='crop image size')
    # NOTE(review): type=bool parses any non-empty string (even "False") as
    # True; harmless for the default-None auto mode, but passing a value on
    # the command line is unreliable.
    parser.add_argument('--sync-bn', type=bool, default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument('--freeze-bn', type=bool, default=False,
                        help='whether to freeze bn parameters (default: False)')
    parser.add_argument('--loss-type', type=str, default='ce',
                        choices=['ce', 'focal'],
                        help='loss func type (default: ce)')
    # training hyper params
    parser.add_argument('--epochs', type=int, default=None, metavar='N',
                        help='number of epochs to train (default: auto)')
    parser.add_argument('--start_epoch', type=int, default=0,
                        metavar='N', help='start epochs (default:0)')
    parser.add_argument('--batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                                training (default: auto)')
    parser.add_argument('--test-batch-size', type=int, default=None,
                        metavar='N', help='input batch size for \
                                testing (default: auto)')
    parser.add_argument('--use-balanced-weights', action='store_true', default=False,
                        help='whether to use balanced weights (default: False)')
    # optimizer params
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (default: auto)')
    parser.add_argument('--lr-scheduler', type=str, default='poly',
                        choices=['poly', 'step', 'cos'],
                        help='lr scheduler mode: (default: poly)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        metavar='M', help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=5e-4,
                        metavar='M', help='w-decay (default: 5e-4)')
    parser.add_argument('--nesterov', action='store_true', default=False,
                        help='whether use nesterov (default: False)')
    # cuda, seed and logging
    parser.add_argument('--no-cuda', action='store_true', default=
                        False, help='disables CUDA training')
    parser.add_argument('--gpu-ids', type=str, default='0',
                        help='use which gpu to train, must be a \
                        comma-separated list of integers only (default=0)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # checking point
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname', type=str, default=None,
                        help='set the checkpoint name')
    # finetuning pre-trained models
    parser.add_argument('--ft', action='store_true', default=False,
                        help='finetuning on a different dataset')
    # evaluation option
    parser.add_argument('--eval-interval', type=int, default=1,
                        help='evaluation interval (default: 1)')
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        try:
            args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
        except ValueError:
            raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')

    # Auto-enable SyncBN only for multi-GPU CUDA training.
    if args.sync_bn is None:
        if args.cuda and len(args.gpu_ids) > 1:
            args.sync_bn = True
        else:
            args.sync_bn = False

    # default settings for epochs, batch_size and lr
    if args.epochs is None:
        # BUG FIX: added an 'own' entry — previously `--dataset own` without
        # an explicit --epochs raised KeyError here.
        epoches = {
            'coco': 30,
            'cityscapes': 200,
            'pascal': 50,
            'own': 50,
        }
        args.epochs = epoches[args.dataset.lower()]

    if args.batch_size is None:
        args.batch_size = 4 * len(args.gpu_ids)

    if args.test_batch_size is None:
        args.test_batch_size = args.batch_size

    if args.lr is None:
        # BUG FIX: added an 'own' entry for the same reason as epochs above.
        # Base LRs are scaled linearly with the effective batch size.
        lrs = {
            'coco': 0.1,
            'cityscapes': 0.01,
            'pascal': 0.007,
            'own': 0.007,
        }
        args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size


    if args.checkname is None:
        args.checkname = 'deeplab-'+str(args.backbone)
    print(args)
    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epoches:', trainer.args.epochs)
    for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
        trainer.training(epoch)
        # Validate every eval_interval epochs (on the last epoch of each interval).
        if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1):
            trainer.validation(epoch)

    trainer.writer.close()

if __name__ == "__main__":
    main()

4.预测代码,原始的代码没有预测,这里给出一下
inference.py

import argparse
import os
import numpy as np 
import tqdm
import torch
import json
from osgeo import gdal
from easydict import EasyDict

from PIL import Image
from modeling.deeplab import *
# from utils.dataloader import seg_utils, make_data_loader
from utils.metrics import Evaluator
from torchvision import transforms

def read_img(filename):
    """Read a raster via GDAL; return (projection, geotransform, width,
    height, pixel_array) covering the full image extent."""
    src = gdal.Open(filename)

    cols = src.RasterXSize
    rows = src.RasterYSize

    transform = src.GetGeoTransform()
    projection = src.GetProjection()
    pixels = src.ReadAsArray(0, 0, cols, rows)

    # Release the GDAL handle before returning.
    del src
    return projection, transform, cols, rows, pixels

def write_img(filename, im_proj, im_geotrans, im_data):
    """Save a numpy array as a GeoTIFF carrying the given georeferencing.

    im_data may be 2-D (single band) or 3-D (bands, height, width); the
    GDAL data type is chosen from the array dtype name.
    """
    # Map numpy dtype -> GDAL type ('int8' also matches 'uint8' by substring).
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # BUG FIX: release the dataset so GDAL flushes the pixel data to disk;
    # without this the written GeoTIFF can end up empty or truncated.
    del dataset


class Tester(object):
    """Run a trained DeepLab checkpoint over a folder of GeoTIFFs and
    write per-pixel class predictions as georeferenced GeoTIFFs."""

    def __init__(self, args):
        """args must provide: model (checkpoint path), num_class, backbone,
        out_stride, data_path, save_path. Raises RuntimeError when the
        checkpoint file does not exist."""
        if not os.path.isfile(args.model):
            raise RuntimeError("no checkpoint found at '{}'".format(args.model))
        self.args = args
        self.nclass = args.num_class

        # Define model (CPU inference, so sync/freeze BN are disabled).
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False)

        self.model = model
        device = torch.device('cpu')
        checkpoint = torch.load(args.model, map_location=device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.evaluator = Evaluator(self.nclass)

    def inference(self):
        """Predict every raster under args.data_path and save the argmax
        class map to args.save_path, preserving projection/geotransform."""
        self.model.eval()
        self.evaluator.reset()

        DATA_DIR = self.args.data_path

        # List the directory once (the original re-listed it per image)
        # and filter macOS metadata files up front.
        test_files = [f for f in os.listdir(DATA_DIR) if f != '.DS_Store']
        for idx, test_file in enumerate(test_files):
            im_proj, im_geotrans, im_width, im_height, test_img = read_img(os.path.join(DATA_DIR, test_file))
            # BUG FIX: the original duplicated chained assignment
            # (`image_id, extension = image_id, extension = ...`) also
            # assumed a 3-character extension via `[0:-4]`; splitext
            # handles any extension length.
            image_id, extension = os.path.splitext(test_file)
            extension = extension.lstrip('.')

            # Keep the first 3 bands and scale to [0, 1], mirroring training.
            test_img = test_img[0:3]
            test_img = test_img / 255.0

            # Add a batch dimension; from_numpy + .float() avoids the
            # torch.tensor(tensor) copy-construct warning.
            test_img = np.expand_dims(test_img, axis=0)
            test_crop_tensor = torch.from_numpy(test_img).float()
            with torch.no_grad():
                output = self.model(test_crop_tensor)
            pred = output.data.cpu().numpy()
            # Per-pixel class = argmax over the channel dimension.
            pred = np.argmax(pred, axis=1)
            inference_imgs = pred[0][:, :]

            print('inference ... {}/{}'.format(idx + 1, len(test_files)))
            write_img(os.path.join(self.args.save_path, image_id + '.' + extension),
                      im_proj, im_geotrans, inference_imgs.astype('uint8'))



def main():
    """Entry point: read the 'inference' section of config.json and run
    prediction when its 'inference' flag is truthy."""
    with open('config.json') as cfg_file:
        config = json.load(cfg_file)
    args = EasyDict(config['inference'])

    tester = Tester(args)

    if args.inference:
        print('predict...')
        tester.inference()

if __name__ == "__main__":
    main()

预测代码对应的json文件,放在同级目录。注意:标准 JSON 不支持注释,文件里不能写 # 注释,否则 json.load 会报错;其中 crop_size 参数要和 train.py 的参数部分保持一致。
config.json

{
    "inference": {
        "inference": 1,
        "num_class": 4, 
        "backbone": "resnet",
        "out_stride": 16,
        "crop_size": 448,
        "data_path": "./deeplab-xception/data/val/images",
        "save_path": "./deeplab-xception/result",
        "model": "./deeplab-xception/run/own/deeplab-resnet/model_best.pth.tar"
    }
}

训练使用的命令行命令:

python train.py --backbone resnet --lr 0.002 --workers 0 --epochs 50 --batch-size 3 --gpu-ids 0 --checkname deeplab-resnet --eval-interval 1 --dataset own

预测的时候改好json文件直接命令行运行inference.py就行了

因为一开始是四波段,用了gdal读取数据,没办法对应做数据增强的一些操作,感兴趣的可以把数据转为png直接用三波段参考pascal.py读入数据,用源码提供的数据增强试试

  • 12
    点赞
  • 92
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 49
    评论
在Win10系统下,使用DeepLab V3进行语义分割训练自己的数据集可以通过以下步骤实现: 1. 准备数据集:首先,需要进行数据集的准备工作。收集大量的图像数据,并为每张图像标注像素级别的语义标签。确保训练图像和标签数据是一一对应的。 2. 安装依赖环境:在Win10系统下,首先需要安装Python和TensorFlow等深度学习框架,并配置好相应的环境变量。确保能够成功导入相关的库和模块。 3. 下载DeepLab V3:从GitHub上下载DeepLab V3的源代码,并解压到本地目录。在命令行中切换到DeepLab V3的根目录。 4. 数据预处理:使用脚本文件对数据集进行预处理,将图像和标签数据转换成模型可接受的格式。这可以通过运行预处理脚本来完成。 5. 配置参数:在配置文件中设置相关的训练参数,如训练图像的路径、标签的路径、模型的参数等。可以根据实际需要进行调整。 6. 运行训练:在命令行中运行训练脚本,该脚本会调用DeepLab V3模型进行训练。根据配置文件中的设置,模型将使用训练数据进行迭代训练,以优化模型的性能。 7. 评估模型:训练完成后,可以运行评估脚本对训练得到的模型进行评估。该脚本将使用测试数据进行预测,并计算出预测结果的准确性。 8. 使用模型:训练完成后,可以使用已训练好的模型对新的图像进行语义分割。通过在命令行中运行预测脚本,将输入图像作为参数进行预测,即可得到相应的语义分割结果。 以上是在Win10系统下使用DeepLab V3进行语义分割训练自己的数据集的基本步骤。根据具体情况和需求,可能还需要进行一些额外的调整和改进。
评论 49
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值