论文模型对比需要对之前模型复现 从0记录实现过程
基础配置:pytorch+cuda,系统:win, 模型:fcn-8s,数据集:Aeroscapes
目录
基础代码来源
step1:打开 GitHub 仓库链接,然后用 git clone 克隆到本地,或者直接 Download ZIP 下载后解压都可以。完成后会看到如下的代码目录结构。
step2:下载所需要的包,即项目目录中的requirements.txt
conda activate yourenv#这里激活你自己的虚拟环境 把包下载到自己的虚拟环境中
pip install -r requirements.txt #下载requirements.txt中的所有包
每个人都会出现各种缺包的情况 一一找出并下载。
下载数据集:Aeroscapes。
训练
代码目录结构如下(接下来用的数据集是aeroscapes,所以要改一些数据集加载的代码)
pytorch-fcn-main/
├── .github/
│ └── workflows/
│ └── FUNDING.yml
├── .readme/
│ └── fcn8s_iter28000.jpg
├── examples/
│ └── voc/
│ ├── .gitignore
│ ├── download_dataset.sh
│ ├── evaluate.py
│ ├── learning_curve.py
│ ├── model_caffe_to_pytorch.py
│ ├── README.md
│ ├── speedtest.py
│ ├── summarize_logs.py
│ ├── train_fcn8s.py
│ ├── train_fcn8s_atonce.py
│ ├── train_fcn16s.py
│ ├── train_fcn32s.py
│ └── view_log
├── tests/
│ └── models_tests/
│ └── test_fcn32s.py
├── torchfcn/
│ ├── datasets/
│ ├── ext/
│ └── models/
│ ├── __init__.py
│ ├── trainer.py
│ └── utils.py
├── .gitignore
├── .gitmodules
├── LICENSE
├── MANIFEST.in
├── README.md
├── requirements.txt
├── setup.cfg
└── setup.py
首先我要用 aeroscapes 数据集训练 fcn8s,所以开始对 examples/voc/train_fcn8s.py 中的代码进行修改。其中超参数部分的代码如下,修改的地方都用注释注明了原因:
def main():
#================================
#所有超参数
#================================
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('-g', '--gpu', type=int, default=0, required=False, help='gpu id') #这里的gpu直接指定了默认值0,然后required=False
parser.add_argument('--resume', help='checkpoint path')
parser.add_argument(
'--max-iteration', type=int, default=100000, help='max iteration'
)
parser.add_argument(
'--lr', type=float, default=1.0e-14, help='learning rate',
)
parser.add_argument(
'--weight-decay', type=float, default=0.0005, help='weight decay',
)
parser.add_argument(
'--momentum', type=float, default=0.99, help='momentum',
)
parser.add_argument(
'--pretrained-model',
default='/root/autodl-tmp/torchfcn/models/fcn8s-heavy-pascal.pth',
help='pretrained model of FCN8s',
)#这里之前作者是去谷歌下载pth,下载会报错,原因你懂的。所以我直接用vpn下载好pth后使用绝对路径加载。
args = parser.parse_args()
args.model = 'FCN8s'
now = datetime.datetime.now()
args.out = osp.join(here, 'logs', now.strftime('%Y%m%d_%H%M%S.%f'))
os.makedirs(args.out)
with open(osp.join(args.out, 'config.yaml'), 'w') as f:
yaml.safe_dump(args.__dict__, f, default_flow_style=False)
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
cuda = torch.cuda.is_available()
torch.manual_seed(1337)
if cuda:
torch.cuda.manual_seed(1337)
加载数据集的地方也做了修改。
# 1. dataset
root = osp.expanduser('/root/autodl-tmp/torchfcn/datasets')# absolute path to the dataset root
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
VOC2007Seg(root, split='train', transform=True),# The original code used SBDClassSeg to load the training set and VOC2011ClassSeg for the validation set below; my Aeroscapes data is stored in VOC2007 layout, so both were replaced by the custom VOC2007Seg class in voc.py (the complete modified voc.py is given below).
batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
VOC2007Seg(root, split='val', transform=True),
batch_size=1, shuffle=False, **kwargs)
完整的新voc.py
#!/usr/bin/env python
import collections
import os.path as osp
import numpy as np
import PIL.Image
import scipy.io
import torch
from torch.utils import data
import os
class VOCClassSegBase(data.Dataset):
    """Base dataset for VOC-layout semantic segmentation data.

    Expects the directory layout::

        <root>/VOC2007/ImageSets/{train,val}.txt
        <root>/VOC2007/JPEGImages/<id>.jpg
        <root>/VOC2007/SegmentationClass/<id>.png

    ``__getitem__`` returns ``(img, lbl)`` — raw numpy arrays, or
    normalized tensors when ``transform=True`` was passed.
    """

    # NOTE(review): these are the original PASCAL VOC class names kept from
    # the upstream repo; for Aeroscapes only the class COUNT (12) matters
    # here — confirm the label ids in the PNG masks line up with this order.
    class_names = np.array([
        'background',
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
    ])

    # Per-channel BGR mean subtracted in `transform` (upstream VOC values;
    # presumably acceptable for Aeroscapes — verify if accuracy matters).
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

    def __init__(self, root, split='train', transform=False):
        """
        Args:
            root: dataset root directory containing a ``VOC2007`` folder.
            split: which image list ``__len__``/``__getitem__`` use
                ('train' or 'val'); both lists are indexed regardless.
            transform: if True, samples are normalized to tensors.
        """
        self.root = root
        self.split = split
        self._transform = transform
        dataset_dir = osp.join(self.root, 'VOC2007')
        self.files = collections.defaultdict(list)
        # Loop variable renamed from `split` (original shadowed the param).
        for split_name in ('train', 'val'):
            imgsets_file = osp.join(dataset_dir, 'ImageSets', f'{split_name}.txt')
            # Context manager closes the handle promptly; the original
            # `for did in open(...)` leaked it.
            with open(imgsets_file) as f:
                for did in f:
                    did = did.strip()
                    img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did)
                    lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did)
                    self.files[split_name].append({'img': img_file, 'lbl': lbl_file})

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one (image, label) pair; see `transform` for tensor output."""
        data_file = self.files[self.split][index]
        # load image as RGB uint8 (H, W, 3)
        img = PIL.Image.open(data_file['img']).convert('RGB')
        img = np.array(img, dtype=np.uint8)
        # load label mask (H, W); 255 is the VOC "void" value
        lbl = PIL.Image.open(data_file['lbl'])
        lbl = np.array(lbl, dtype=np.int32)
        lbl[lbl == 255] = -1  # remap void pixels to the loss ignore index
        if self._transform:
            return self.transform(img, lbl)
        return img, lbl

    def transform(self, img, lbl):
        """RGB uint8 (H, W, 3) -> mean-subtracted BGR float CHW tensor."""
        img = img[:, :, ::-1]  # RGB -> BGR (the pretrained weights expect BGR)
        img = img.astype(np.float64)
        img -= self.mean_bgr
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def untransform(self, img, lbl):
        """Inverse of `transform`: tensors back to displayable numpy arrays."""
        img = img.numpy()
        img = img.transpose(1, 2, 0)  # CHW -> HWC
        img += self.mean_bgr
        img = img.astype(np.uint8)
        img = img[:, :, ::-1]  # BGR -> RGB
        lbl = lbl.numpy()
        return img, lbl
class VOC2007Seg(VOCClassSegBase):
    """Aeroscapes dataset stored in PASCAL VOC 2007 directory layout.

    Unlike the base class, the file index is rebuilt here for the
    requested split only (``ImageSets/train.txt`` or ``ImageSets/val.txt``).
    """

    def __init__(self, root, split='train', transform=False):
        """
        Args:
            root: dataset root containing the ``VOC2007`` folder.
            split: 'train' or 'val'.
            transform: if True, samples are returned as normalized tensors.

        Raises:
            ValueError: if ``split`` is neither 'train' nor 'val'.
        """
        super(VOC2007Seg, self).__init__(root, split=split, transform=transform)
        dataset_dir = osp.join(self.root, 'VOC2007')
        # Pick the image-list file for the requested split.
        if split == 'train':
            imgsets_file = osp.join(dataset_dir, 'ImageSets/train.txt')
        elif split == 'val':
            imgsets_file = osp.join(dataset_dir, 'ImageSets/val.txt')
        else:
            raise ValueError(f"Unsupported split: {split}")
        # Rebuild the index for this split only (base __init__ indexed both).
        self.files = collections.defaultdict(list)
        # Context manager closes the list file; the original open() leaked it.
        with open(imgsets_file) as f:
            for did in f:
                did = did.strip()
                img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did)
                lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % did)
                self.files[split].append({'img': img_file, 'lbl': lbl_file})

    # __getitem__ override removed: it was a byte-for-byte duplicate of
    # VOCClassSegBase.__getitem__, so the inherited implementation is used.
train_fcn8s.py 中对 train_fcn32s.py 里 git_hash 函数的调用全部删除了:该函数只是输出当前代码的 git 版本号,调用 git 时又会涉及代理问题,删掉它不影响训练。
此时运行train_fcn8s.py成功,在examples/voc/logs中保存结果。
预测
解压 model_best.pth.tar,之后测试时直接加载解压得到的 model_best.pth 即可。