FCN训练Aeroscapes数据集

论文模型对比需要对之前模型复现 从0记录实现过程

基础配置:pytorch+cuda,系统:win, 模型:fcn-8s,数据集:Aeroscapes

目录

基础代码来源

  step1:链接: github,然后用gitclone到本地或者download zip解压都可以。会看到一个庞大的代码文件如下。在这里插入图片描述
  step2:下载所需要的包,即项目目录中的requirements.txt

conda activate yourenv#这里激活你自己的虚拟环境 把包下载到自己的虚拟环境中
pip install -r requirements.txt #下载requirements.txt中的所有包

每个人都会出现各种缺包的情况 一一找出并下载。
   下载数据集,AeroscapesAeroscapes

训练

代码目录结构如下(接下来用的数据集是aeroscapes,所以要改一些数据集加载的代码)

pytorch-fcn-main/
├── .github/
│   └── workflows/
│       └── FUNDING.yml
├── .readme/
│   └── fcn8s_iter28000.jpg
├── examples/
│   └── voc/
│       ├── .gitignore
│       ├── download_dataset.sh
│       ├── evaluate.py
│       ├── learning_curve.py
│       ├── model_caffe_to_pytorch.py
│       ├── README.md
│       ├── speedtest.py
│       ├── summarize_logs.py
│       ├── train_fcn8s.py
│       ├── train_fcn8s_atonce.py
│       ├── train_fcn16s.py
│       ├── train_fcn32s.py
│       └── view_log
├── tests/
│   └── models_tests/
│       └── test_fcn32s.py
├── torchfcn/
│   ├── datasets/
│   ├── ext/
│   └── models/
│       ├── __init__.py
│       ├── trainer.py
│       └── utils.py
├── .gitignore
├── .gitmodules
├── LICENSE
├── MANIFEST.in
├── README.md
└── requirements.txt
└── setup.cfg
└── setup.py

  首先我要用aeroscpaes数据集训练fcn8s,所以开始对examples/voc/train_fcn8s.py中的代码进行修改。其中超参数的代码如下:修改的地方用注释注明了原因

def main():
    #================================
    #所有超参数
    #================================
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-g', '--gpu', type=int, default=0, required=False, help='gpu id') #这里的gpu直接指定了默认值0,然后required=False
    parser.add_argument('--resume', help='checkpoint path')
    parser.add_argument(
        '--max-iteration', type=int, default=100000, help='max iteration'
    )
    parser.add_argument(
        '--lr', type=float, default=1.0e-14, help='learning rate',
    )
    parser.add_argument(
        '--weight-decay', type=float, default=0.0005, help='weight decay',
    )
    parser.add_argument(
        '--momentum', type=float, default=0.99, help='momentum',
    )
    parser.add_argument(
        '--pretrained-model',
        default='/root/autodl-tmp/torchfcn/models/fcn8s-heavy-pascal.pth',
        help='pretrained model of FCN8s',
    )#这里之前作者是去谷歌下载pth,下载会报错,原因你懂的。所以我直接用vpn下载好pth后使用绝对路径加载。
    args = parser.parse_args()

    args.model = 'FCN8s'

    now = datetime.datetime.now()
    args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S.%f'))

    os.makedirs(args.out)
    with open(osp.join(args.out, 'config.yaml'), 'w') as f:
        yaml.safe_dump(args.__dict__, f, default_flow_style=False)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

  加载数据集的地方也做了修改。

# 1. dataset
    root = osp.expanduser('/root/autodl-tmp/torchfcn/datasets')#这里用了绝对路径加载数据集
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}

    train_loader = torch.utils.data.DataLoader(
        VOC2007Seg(root, split='train', transform=True),#这里原本代码是SBDClassSeg来加载训练集和VOC2011ClassSeg加载下面的测试集,而我的aeroscapes数据集是VOC2007格式,所以自己将原本的这两个类修改了,这两个类在voc.py文件中。(下面我会给出完整的修改后的voc.py文件)
        batch_size=1, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        VOC2007Seg(root, split='val', transform=True),
        batch_size=1, shuffle=False, **kwargs)

  完整的新voc.py

#!/usr/bin/env python

import collections
import os.path as osp

import numpy as np
import PIL.Image
import scipy.io
import torch
from torch.utils import data
import os

class VOCClassSegBase(data.Dataset):

    class_names = np.array([
        'background',
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
    ])
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

    def __init__(self, root, split='train', transform=False):
        self.root = root
        self.split = split
        self._transform = transform

        # VOC2011 and others are subset of VOC2012
        dataset_dir = osp.join(self.root, 'VOC2007')
        self.files = collections.defaultdict(list)
        for split in ['train', 'val']:
            imgsets_file = osp.join(dataset_dir, 'ImageSets', f'{split}.txt')###

            for did in open(imgsets_file):
                did = did.strip()
                img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did)
                lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did)
                self.files[split].append({'img': img_file, 'lbl': lbl_file,})

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        data_file = self.files[self.split][index]
        # load image
        img_file = data_file['img']
        img = PIL.Image.open(img_file).convert('RGB')
        img = np.array(img, dtype=np.uint8)
        # load label
        lbl_file = data_file['lbl']
        lbl = PIL.Image.open(lbl_file)
        lbl = np.array(lbl, dtype=np.int32)
        lbl[lbl == 255] = -1 # Ignore index
        if self._transform:
            return self.transform(img, lbl)
        else:
            return img, lbl

    def transform(self, img, lbl):
        img = img[:, :, ::-1]  # RGB -> BGR
        img = img.astype(np.float64)
        img -= self.mean_bgr
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def untransform(self, img, lbl):
        img = img.numpy()
        img = img.transpose(1, 2, 0)
        img += self.mean_bgr
        img = img.astype(np.uint8)
        img = img[:, :, ::-1]
        lbl = lbl.numpy()
        return img, lbl


class VOC2007Seg(VOCClassSegBase):
    def __init__(self, root, split='train', transform=False):
        super(VOC2007Seg, self).__init__(root, split=split, transform=transform)
        dataset_dir = osp.join(self.root, 'VOC2007')

        # 根据传入的split参数决定使用trn.txt还是val.txt
        if split == 'train':
            imgsets_file = osp.join(dataset_dir, 'ImageSets/train.txt')
        elif split == 'val':
            imgsets_file = osp.join(dataset_dir, 'ImageSets/val.txt')
        else:
            raise ValueError(f"Unsupported split: {split}")

        self.files = collections.defaultdict(list)
        for did in open(imgsets_file):
            did = did.strip()
            img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did)
            lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did)
            self.files[split].append({'img': img_file, 'lbl': lbl_file})

    def __getitem__(self, index):
        data_file = self.files[self.split][index]
        img_file = data_file['img']
        img = PIL.Image.open(img_file).convert('RGB')
        img = np.array(img, dtype=np.uint8)
        lbl_file = data_file['lbl']
        lbl = PIL.Image.open(lbl_file)
        lbl = np.array(lbl, dtype=np.int32)
        lbl[lbl == 255] = -1  # Ignore index
        if self._transform:
            return self.transform(img, lbl)
        else:
            return img, lbl

  train_fcn8s.py中调用train_fcn32s.py的git_hash函数全部删除了,因为这个函数就是输出当前code的版本,这里又会涉及到代理问题,所以删除了不影响
  此时运行train_fcn8s.py成功,在examples/voc/logs中保存结果。
在这里插入图片描述

预测

  解压model_best.pth.tar,模型现在加载这个model_best.pth测试即可。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小马敲马

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值