EGNet---code(三)训练深入分析

run.py 
# =>1 get_loader

        train_loader, dataset = get_loader

# =>2 config.save_fold

config.save_fold = './EGNet/run-nnet'

# =>3 train = Solver()

train = Solver(train_loader, None, config)

# =>4 train.train()

train.train()

首先分析
# =>1 get_loader

from dataset import get_loader
dataset.py

def get_loader(batch_size, mode='train', num_thread=1, test_mode=0, sal_mode='e'):
    """Build the dataset and its DataLoader for the requested mode.

    Args:
        batch_size: Number of samples per batch handed to the DataLoader.
        mode: Only 'train' is handled in this excerpt; it selects
            ImageDataTrain and turns shuffling on.
        num_thread: Forwarded to the DataLoader as ``num_workers``.
        test_mode: Not used by the code shown here.
        sal_mode: Not used by the code shown here.

    Returns:
        A ``(data_loader, dataset)`` tuple.

    Note:
        For any mode other than 'train' no dataset is created here, so the
        DataLoader construction fails because ``dataset`` is unbound.
    """
    # Sample call noted by the original author: get_loader(batch=1, test, 4, 1, e)
    is_train = (mode == 'train')
    if is_train:
        dataset = ImageDataTrain()
    loader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=is_train,  # shuffle only while training
        num_workers=num_thread,
    )
    return loader, dataset

分析
# =>1 get_loader中的 
ImageDataTrain()读入图片:jpg,png,edge

如下,
sal_root:训练数据集根目录
sal_source:.lst 列表文件的路径(每行依次给出原图、显著性标注、边缘标注三个相对路径)
逐行读取 .lst 文件,去掉首尾空白后存入 sal_list,
并用 sal_num 记录样本总数


class ImageDataTrain(data.Dataset):
    """Index of the DUTS-TR training samples, read from a ``.lst`` list file."""

    def __init__(self):
        """Read the list file and remember its entries and their count."""
        # Paths the original author used on their own machine:
        #   /home/liuj/dataset/DUTS/DUTS-TR
        #   /home/liuj/dataset/DUTS/DUTS-TR/train_pair_edge.lst
        self.sal_root = './DUTS-TR'
        self.sal_source = './DUTS-TR/train_pair_edge.lst'

        # One whitespace-stripped entry per line of the list file.
        with open(self.sal_source, 'r') as list_file:
            entries = [line.strip() for line in list_file]

        self.sal_list = entries
        self.sal_num = len(entries)  # number of entries in the list
 

self.sal_root = './DUTS-TR'
self.sal_source = './DUTS-TR/train_pair_edge.lst'

train_pair_edge.lst文件如下

(每一行包含三个以空白分隔的相对路径,split() 之后依次为:)
DUTS-TR-Image/ILSVRC2012_test_00000018.jpg        -> self.sal_list[item].split()[0]
DUTS-TR-Mask/ILSVRC2012_test_00000018.png         -> self.sal_list[item].split()[1]
DUTS-TR-Mask/ILSVRC2012_test_00000018_edge.png    -> self.sal_list[item].split()[2]

下面是
os.path.join(self.sal_root, 相对路径),即 './DUTS-TR/' + 相对路径

# load_image ('./DUTS-TR/DUTS-TR-Image/ILSVRC2012_test_00000018.jpg'
sal_image = load_image(os.path.join(self.sal_root, self.sal_list[item % self.sal_num].split()[0]))
# load_image ('./DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000018.png'
sal_label = load_sal_label(os.path.join(self.sal_root, self.sal_list[item % self.sal_num].split()[1]))
# load_edge_label ('./DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000018_edge.png')
sal_edge = load_edge_label(os.path.join(self.sal_root, self.sal_list[item % self.sal_num].split()[2]))

cv_random_flip:对图像、标注和边缘图做随机翻转

读入张量
sal_image = torch.Tensor(sal_image)
sal_label = torch.Tensor(sal_label)
sal_edge = torch.Tensor(sal_edge)

    def __getitem__(self, item):
        """Load one training sample: image, saliency label and edge label.

        Args:
            item: Sample index; it is wrapped with ``% self.sal_num`` so any
                non-negative index maps onto an entry of ``self.sal_list``.

        Returns:
            dict with keys 'sal_image', 'sal_label' and 'sal_edge', each a
            ``torch.Tensor`` built from the corresponding loaded array.
        """
        # Each list entry holds three whitespace-separated paths relative to
        # sal_root: image, saliency mask, edge mask. Split it once.
        fields = self.sal_list[item % self.sal_num].split()

        # e.g. './DUTS-TR/DUTS-TR-Image/ILSVRC2012_test_00000018.jpg'
        sal_image = load_image(os.path.join(self.sal_root, fields[0]))
        # e.g. './DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000018.png'
        sal_label = load_sal_label(os.path.join(self.sal_root, fields[1]))
        # e.g. './DUTS-TR/DUTS-TR-Mask/ILSVRC2012_test_00000018_edge.png'
        sal_edge = load_edge_label(os.path.join(self.sal_root, fields[2]))

        # The same random flip is applied to all three arrays (helper not
        # shown here; presumably keeps them spatially aligned - TODO confirm).
        sal_image, sal_label, sal_edge = cv_random_flip(sal_image, sal_label, sal_edge)

        sal_image = torch.Tensor(sal_image)
        sal_label = torch.Tensor(sal_label)
        sal_edge = torch.Tensor(sal_edge)

        sample = {'sal_image': sal_image, 'sal_label': sal_label, 'sal_edge': sal_edge}
        # Fix: return the assembled sample, not the dataset size.
        return sample

    def __len__(self):
        """Return the number of entries read from the list file."""
        # This is where the stray `return self.sal_num` belongs.
        return self.sal_num

# =>3 train = Solver()
train = Solver(train_loader, None, config)

=> solve.py文件

class Solver(object):
    """Holds the training loader, the config, the network and the log file."""

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """Store the inputs, build the model and open the training log.

        Args:
            train_loader: Loader kept as ``self.train_loader``.
            test_loader: Accepted but not stored by the code shown here.
            config: Configuration object; ``mode`` and ``save_fold`` are read.
            save_fold: Accepted but not used by the code shown here.
        """
        self.train_loader, self.config = train_loader, config
        # Three per-channel values reshaped to (3, 1, 1) and scaled by 1/255
        # (original note: "moves the rgb axis to the front").
        channel_values = torch.Tensor([123.68, 116.779, 103.939])
        self.mean = channel_values.view(3, 1, 1) / 255.
        self.build_model()  # method of this class, defined elsewhere in the solver file
        if config.mode != 'train':
            return
        # The log file is only opened when training.
        log_path = "%s/logs/log.txt" % config.save_fold
        self.log_output = open(log_path, 'w')

self.build_model() 
调用的是solve.py文件中的def build_model()函数,如下:

    # build the network
    def build_model(self):
        """Create the backbone network and the Adam optimizer that trains it.

        Reads the module-level ``base_model_cfg`` and the hyper-parameter
        dict ``p``; sets ``self.net_bone``, ``self.lr_bone``,
        ``self.lr_branch`` and ``self.optimizer_bone``.
        """
        # `build_model` here resolves to the module-level factory
        # (from model import build_model; base_model_cfg = 'resnet'),
        # not to this method.
        self.net_bone = build_model(base_model_cfg)
        # Original note: "use_global_stats = True" (intent unconfirmed).
        self.net_bone.eval()
        self.net_bone.apply(weights_init)

        # Pretrained backbone weights are loaded only when training with no
        # checkpoint given (load_bone defaults to '') and a resnet backbone.
        fresh_training = self.config.mode == 'train' and self.config.load_bone == ''
        if fresh_training and base_model_cfg == 'resnet':
            # Loaded onto the CPU; the original note marks cuda as a todo.
            backbone_weights = torch.load(self.config.resnet, map_location=torch.device('cpu'))
            self.net_bone.base.load_state_dict(backbone_weights)

        self.lr_bone = p['lr_bone']
        self.lr_branch = p['lr_branch']

        # Values of `p` quoted by the original author:
        #   p['lr_bone'] = 5e-5      Learning rate resnet:5e-5, vgg:2e-5
        #   p['lr_branch'] = 0.025   Learning rate
        #   p['wd'] = 0.0005         Weight decay
        #   p['momentum'] = 0.90     Momentum
        # Only parameters that require gradients are handed to the optimizer.
        trainable = filter(lambda param: param.requires_grad, self.net_bone.parameters())
        self.optimizer_bone = Adam(trainable, lr=self.lr_bone, weight_decay=p['wd'])

        self.print_network(self.net_bone, 'trueUnify bone part')

def build_model()函数里面的
       def build_model(self):
              self.net_bone = build_model(base_model_cfg)

这里的build_model(base_model_cfg)
from model import build_model

=> model.py文件

# build the whole network
def build_model(base_model_cfg='vgg'):
    """Build the whole TUN_bone network for the given backbone name.

    Only the 'resnet' branch is present in this excerpt; the branch for the
    default 'vgg' value is not shown. Any value other than 'resnet'
    therefore falls through and returns None.

    Args:
        base_model_cfg: Backbone name; 'resnet' builds on ``resnet50()``.

    Returns:
        A TUN_bone instance for 'resnet', otherwise None.
    """
    # Fix: this was a dangling `elif` with no preceding `if` (SyntaxError).
    if base_model_cfg == 'resnet':
        # extra_layer returns (base, merge1_layers, merge2_layers); the star
        # unpacks them into TUN_bone's remaining positional parameters.
        return TUN_bone(base_model_cfg, *extra_layer(base_model_cfg, resnet50()))
    return None

->1

TUN_bone(base_model_cfg, *extra_layer(base_model_cfg, resnet50()))是关键函数

*extra_layer(base_model_cfg, resnet50())
表示先调用 extra_layer(),再用 * 把它返回的元组 (vgg, merge1_layers, merge2_layers) 解包,依次作为 TUN_bone 的 base、merge1_layers、merge2_layers 三个位置参数传入

->2

class TUN_bone(nn.Module):
    """Backbone plus the two merge stages, wired together in ``forward``."""

    def __init__(self, base_model_cfg, base, merge1_layers, merge2_layers):  # vgg, merge1_layers, merge2_layers
        """Store the backbone and merge modules for the chosen backbone type.

        Args:
            base_model_cfg: 'vgg' or 'resnet'.
            base: Backbone module; called on the input in ``forward``.
            merge1_layers: First merge stage; called with the backbone
                features and the input's spatial size.
            merge2_layers: Second merge stage; called with the edge feature,
                the saliency features and the input's spatial size.
        """
        super(TUN_bone, self).__init__()  # use the __init__ from father class
        self.base_model_cfg = base_model_cfg

        # Fix: the excerpt began with a dangling `elif` (SyntaxError). The
        # 'vgg' branch is restored so this agrees with the full class.
        if self.base_model_cfg == 'vgg':
            self.base = base
            self.merge1 = merge1_layers
            self.merge2 = merge2_layers

        elif self.base_model_cfg == 'resnet':
            # Only the resnet backbone gets a convert layer.
            self.convert = ConvertLayer(config_resnet['convert'])
            self.base = base
            self.merge1 = merge1_layers
            self.merge2 = merge2_layers

    def forward(self, x):
        """Run backbone -> (convert) -> merge1 -> merge2.

        Returns:
            Tuple ``(up_edge, up_sal, up_sal_final)``.
        """
        x_size = x.size()[2:]  # every dimension after the first two
        conv2merge = self.base(x)
        if self.base_model_cfg == 'resnet':
            conv2merge = self.convert(conv2merge)
        up_edge, edge_feature, up_sal, sal_feature = self.merge1(conv2merge, x_size)
        up_sal_final = self.merge2(edge_feature, sal_feature, x_size)
        return up_edge, up_sal, up_sal_final


# TUN network
class TUN_bone(nn.Module):
    """TUN network: a backbone followed by two merge stages."""

    def __init__(self, base_model_cfg, base, merge1_layers, merge2_layers):
        """Attach the backbone and merge modules for 'vgg' or 'resnet'.

        For any other ``base_model_cfg`` nothing besides the config name is
        stored, exactly as in the branch-per-backbone form.
        """
        super().__init__()
        self.base_model_cfg = base_model_cfg

        # The convert layer exists only for resnet and is attached first, so
        # the attribute order matches the original resnet branch.
        if base_model_cfg == 'resnet':
            self.convert = ConvertLayer(config_resnet['convert'])

        # Both supported backbones store the same three modules.
        if base_model_cfg in ('vgg', 'resnet'):
            self.base = base
            self.merge1 = merge1_layers
            self.merge2 = merge2_layers

    def forward(self, x):
        """Return ``(up_edge, up_sal, up_sal_final)`` for input ``x``."""
        spatial_size = x.size()[2:]  # every dimension after the first two
        features = self.base(x)
        if self.base_model_cfg == 'resnet':
            features = self.convert(features)
        up_edge, edge_feature, up_sal, sal_feature = self.merge1(features, spatial_size)
        return up_edge, up_sal, self.merge2(edge_feature, sal_feature, spatial_size)

->3

def extra_layer(base_model_cfg, vgg):

# extra part
def extra_layer(base_model_cfg, vgg):
    """Build the two merge stages that sit on top of the backbone.

    Args:
        base_model_cfg: 'vgg' or 'resnet'; selects config_vgg or
            config_resnet. Any other value leaves the config unbound and the
            call fails.
        vgg: The backbone network, passed through unchanged (despite the
            name it is also used for the resnet backbone).

    Returns:
        Tuple ``(vgg, merge1_layers, merge2_layers)``.
    """
    if base_model_cfg == 'resnet':
        config = config_resnet
    elif base_model_cfg == 'vgg':
        config = config_vgg

    merge1_layers = MergeLayer1(config['merge1'])
    return vgg, merge1_layers, MergeLayer2(config['merge2'])

# Layer configuration used with the 'vgg' backbone.
# Original note: no convert layer, no conv6.
config_vgg = {
    'convert': [
        [128, 256, 512, 512, 512],
        [64, 128, 256, 512, 512],
    ],
    'merge1': [
        [128, 256, 128, 3, 1],
        [256, 512, 256, 3, 1],
        [512, 0, 512, 5, 2],
        [512, 0, 512, 5, 2],
        [512, 0, 512, 7, 3],
    ],
    'merge2': [
        [128],
        [256, 512, 512, 512],
    ],
}

# Layer configuration used with the 'resnet' backbone.
config_resnet = {
    'convert': [
        [64, 256, 512, 1024, 2048],
        [128, 256, 512, 512, 512],
    ],
    'deep_pool': [
        [512, 512, 256, 256, 128],
        [512, 256, 256, 128, 128],
        [False, True, True, True, False],
        [True, True, True, True, False],
    ],
    'score': 256,
    'edgeinfo': [[16, 16, 16, 16], 128, [16, 8, 4, 2]],
    'edgeinfoc': [64, 128],
    'block': [
        [512, [16]],
        [256, [16]],
        [256, [16]],
        [128, [16]],
    ],
    'fuse': [[16, 16, 16, 16], True],
    'fuse_ratio': [[16, 1], [8, 1], [4, 1], [2, 1]],
    'merge1': [
        [128, 256, 128, 3, 1],
        [256, 512, 256, 3, 1],
        [512, 0, 512, 5, 2],
        [512, 0, 512, 5, 2],
        [512, 0, 512, 7, 3],
    ],
    'merge2': [
        [128],
        [256, 512, 512, 512],
    ],
}

 

 

 

 

 

 

 

 

 

# =>4 train.train()
train.train()
这里主要讲了如何设计loss

 

train = Solver(train_loader, None, config)

class Solver(object):
    """Training driver: keeps the loader and config, builds the network."""

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """Set up the solver; called as ``Solver(train_loader, None, config)``.

        ``test_loader`` and ``save_fold`` are accepted but not used by the
        code shown here.
        """
        self.train_loader = train_loader
        self.config = config

        # Three per-channel values, reshaped to (3, 1, 1) and divided by 255
        # (original note: "moves the rgb axis to the front").
        raw_mean = torch.Tensor([123.68, 116.779, 103.939])
        self.mean = raw_mean.view(3, 1, 1) / 255.

        # inference: choose the side map (see paper)
        self.build_model()  # uses build_model / weights_init imported from model

        is_training = config.mode == 'train'
        if is_training:
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')

class Solver(object):
    """Training driver, as created by ``Solver(train_loader, None, config)``."""

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """Keep the inputs, build the network and open the log when training.

        :param train_loader: loader stored as ``self.train_loader``
        :param test_loader:  None in the call shown; not stored here
        :param config:       configuration; ``mode`` and ``save_fold`` are read
        :param save_fold:    None by default; not used here
        """
        self.train_loader = train_loader
        self.config = config
        # Per-channel values shaped (3, 1, 1), scaled by 1/255
        # (original note: "moves the rgb axis to the front").
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        # inference: choose the side map (see paper)
        self.build_model()  # method of this class, defined elsewhere in the solver file
        if config.mode == 'train':
            log_file = "%s/logs/log.txt" % config.save_fold
            self.log_output = open(log_file, 'w')

 

 

 

 

 

 

 

 

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

计算机视觉-Archer

图像分割没有团队的同学可加群

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值