目的:运行并粗略看懂ML-GCN的代码。注:此代码为更改后的代码,结构与原来模型相差甚远。但代码结构大致相同。
代码地址:https://github.com/chenzhaomin123/ML_GCN
论文地址 https://arxiv.org/abs/1904.03582
目录
一、相关依赖项下载
1.1 程序及数据
https://github.com/chenzhaomin123/ML_GCN
1.2 数据集
微软发布的 COCO 数据库是一个大型图像数据集, 专为对象检测、分割、人体关键点检测、语义分割和字幕生成而设计。
COCO 数据库的网址是:
- MS COCO 数据集主页:http://mscoco.org/
- Github 网址:https://github.com/Xinering/cocoapi
- 关于 API 更多的细节在网站: http://mscoco.org/dataset/#download
运用coco2014数据集,数据集较大
-rw-rw-r-- 1 xingxiangrui xingxiangrui 13G Apr 23 11:20 train2014.zip
-rw-rw-r-- 1 xingxiangrui xingxiangrui 6.2G Apr 23 11:23 val2014.zip
训练集与验证集数量
train2014$ ls -l |grep "^-"|wc -l
82783
val2014$ ls -l |grep "^-"|wc -l
40504
尺寸为640*426
1.3 放入对应位置
调用关系及位置:
in general_trian.py
train_dataset = COCO2014(args.data, phase='train', inp_name=Config.INP_NAME, is_grouping=True) # fixme
DATA = 'data/data/coco'
in coco.py
tmpdir = os.path.join(root, 'tmp/')
data = os.path.join(root, 'data/')
if not os.path.exists(data):
os.makedirs(data)
if not os.path.exists(tmpdir):
os.makedirs(tmpdir)
if phase == 'train':
filename = 'train2014.zip'
elif phase == 'val':
filename = 'val2014.zip'
cached_file = os.path.join(tmpdir, filename)
# extract file
img_data = os.path.join(data, filename.split('.')[0])
if not os.path.exists(img_data):
print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=data))
command = 'unzip {} -d {}'.format(cached_file, data)
os.system(command)
root= 'data/data/coco'
tmpdir='data/data/coco/tmp/'
data='data/data/coco /data/'
路径位置应该为这样 /data/data/coco/ data/train2014.zip
解压后图片位置为: /data/data/coco/ data/
(如果不按照这个路径放好数据,程序会重新下载并安装)
1.4 标注 annotations位置
# train/val images/annotations
cached_file = os.path.join(tmpdir, 'annotations_trainval2014.zip')
if not os.path.exists(cached_file):
print('Downloading: "{}" to {}\n'.format(urls['annotations'], cached_file))
os.chdir(tmpdir)
subprocess.Popen('wget ' + urls['annotations'], shell=True)
os.chdir(root)
annotations_data = os.path.join(data, 'annotations')
if not os.path.exists(annotations_data):
print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=data))
command = 'unzip {} -d {}'.format(cached_file, data)
os.system(command)
print('[annotation] Done!')
tmpdir= data/data/coco /tmp/
zip 压缩包存放的位置: cached_file= data/data/coco /tmp/ annotations_trainval2014.zip
解压后文件的位置: data='data/data/coco /data/'
敲击命令行,除了网络信息之外,打印出下面,即表明数据集没有问题。
[dataset] Done!
[annotation] Done!
[json] Done!
相应pkl文件也应该放入相应文件夹内(原版代码需要把data前的/去掉,原版有/data这个目录)
DATA = 'data/data/coco'
INP_NAME = 'data/data/coco/coco_glove_word2vec.pkl'
ADJ_FILE = 'data/data/coco/coco_adj.pkl'
1.5 环境及依赖项
没有的话直接 pip install ***
- numpy
- torch-0.3.1
- torchnet
- torchvision-0.2.0
- tqdm
命令行
原版ML-GCN
lr
: learning rate 学习率lrp
: factor for learning rate of pretrained layers. The learning rate of the pretrained layers islr * lrp,预训练层的因子,需要乘以学习率
batch-size
: number of images per batchimage-size
: size of the imageepochs
: number of training epochsevaluate
: evaluate model on validation set 在验证集上进行validate,评估模型resume
: path to checkpoint 即checkpoint的路径
Demo VOC 2007
python3 demo_voc2007_gcn.py data/voc --image-size 448 --batch-size 32 -e --resume checkpoint/voc/voc_checkpoint.pth.tar
Demo COCO 2014
python3 demo_coco_gcn.py data/coco --image-size 448 --batch-size 32 -e --resume checkpoint/coco/coco_checkpoint.pth.tar
我们的代码:
python general_train.py
(torch031) [xingxiangrui@gzbh-mms-gpu55.gzbh.baidu.com chun-ML_GCN]$ python general_train.py
{'batch_size': 32,
'data': 'data/data/coco',
'device_ids': [0, 1, 2, 3],
'epoch_step': 30,
'epochs': 100,
'evaluate': False,
'image_size': 448,
'lr': 0.01,
'lrp': 0.001,
'momentum': 0.9,
'print_freq': 10,
'resume': './checkpoint/coco/exp_4/model_best_79.8707.pth.tar',
'start_epoch': 0,
'weight_decay': 1e-06,
'workers': 4}
[dataset] Done!
[annotation] Done!
[json] Done!
[dataset] Done!
[annotation] Done!
[json] Done!
Number of model parameters: 65196189
<torchvision.transforms.transforms.Compose object at 0x7f05d4c4eeb8>
=> no checkpoint found at './checkpoint/coco/exp_4/model_best_79.8707.pth.tar'
backbone learning rate 0.001
head learning rate 0.01
Epoch: [0][0/2565] Time 22.885 (22.885) Data 1.159 (1.159) Loss 0.7680 (0.7680)
Epoch: [0][10/2565] Time 1.358 (3.296) Data 0.000 (0.106) Loss 0.6504 (0.7201)
Epoch: [0][20/2565] Time 1.223 (2.334) Data 0.000 (0.056) Loss 0.5153 (0.6486)
Epoch: [0][30/2565] Time 1.185 (1.996) Data 0.000 (0.038) Loss 0.4166 (0.5850)
Epoch: [0][40/2565] Time 1.268 (1.809) Data 0.000 (0.029) Loss 0.3573 (0.5346)
用torch0.4.1到后面会报错显存不够,我们需要用torch0.3.1,
python demo_coco_gcn.py data/data/coco --image-size 448 --batch-size 32 --epochs 100 -e --resume checkpoint/coco/checkpoint.pth.tar
二、代码结构
2.1 原始ML-GCN
不同数据集上有不同的代码,我们以coco代码为准
pytorch代码训练过程基本为一个套路
- 创建参数
- 创建模型
- 加载数据
- 定义loss
- 定义optimizer
- train
2.2 可选项
模型结构
三种模型结构,hgat_fc, hgat_conv, groupnet(可以理解为baseline)
# fixme=============begin=========
if Config.MODEL == 'hgat_fc':
import mymodels.hgat_fc as hgat_fc
model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
nclasses_per_group=Config.NCLASSES_PER_GROUP,
group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
elif Config.MODEL == 'hgat_conv':
import mymodels.hgat_conv as hgat_conv
model = hgat_conv.HGAT_CONV(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
nclasses_per_group=Config.NCLASSES_PER_GROUP,
group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
elif Config.MODEL == 'groupnet':
pass
else:
raise Exception()
print('Number of model parameters: {}'.format(
sum([p.data.nelement() for p in model.parameters()])))
loss设置
可以在这三种loss之中选择一种
BCEWithLogitsLoss, MultiLabelSoftMarginLoss, DeepMarLoss
if Config.LOSS_TYPE == 'MultiLabelSoftMarginLoss':
criterion = nn.MultiLabelSoftMarginLoss()
elif Config.LOSS_TYPE == 'BCEWithLogitsLoss':
criterion = nn.BCEWithLogitsLoss()
elif Config.LOSS_TYPE == 'DeepMarLoss':
criterion = F.binary_cross_entropy_with_logits
else:
raise Exception()
2.3 模型定义位置
以MODEL = 'hgat_fc' 为准,在mymodels中hgat_fc.py之中。
import mymodels.hgat_fc as hgat_fc
model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
nclasses_per_group=Config.NCLASSES_PER_GROUP,
group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
其中:
class HGAT_FC(nn.Module):
def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
super(HGAT_FC, self).__init__()
三、自动断开及其解决
3.1 问题描述
Test: [510/1255] Time 0.501 (0.559) Data 0.000 (0.002) Loss 0.0975 (0.1163)
Test: [520/1255] Time 0.535 (0.559) Data 0.000 (0.002) Loss 0.1211 (0.1163)
packet_write_wait: Connection to 10.44.67.42 port 22: Broken pipe
https://blog.csdn.net/weixin_36474809/article/details/88710505
已经用这个方法设置了非自动断开,但是运行代码时候会自动断开,可能因为运行时间过长。因此我们需要重新设置代码。
3.2 运用nohup指令运行
直接更改general_train.sh文件
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python -u demo_coco_hgat.py > ./train_logs/exp_4.log 2>&1 &
- -u参数的使用
python命令加上-u(unbuffered)参数后会强制其标准输出也同标准错误一样不通过缓存直接打印到屏幕。
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python -u general_train.py