opengait之gaitpart粗略解析

gaitpart代码解析

opengait源码

1、从forward函数开始解析,将会遇到下面四行代码,进而解析如下

ipts, labs, _, _, seqL = inputs
sils = ipts[0]  # [n, s, h, w]
if len(sils.size()) == 4:
	sils = sils.unsqueeze(1)
del ipts

这四行代码是固定不变的,各部分作用如下注释,可以暂时视为不变量,了解即可,先不要作为重心。

# 将输入数据、标签和序列长度从 inputs 中解包出来。
ipts, labs, _, _, seqL = inputs
# 这一行代码将输入数据中的第一个张量赋值给了变量 sils。
# sils 是一个四维张量,其形状为 [n, s, h, w],其中 n 表示样本数量,s 表示序列长度,
# h 和 w 分别表示图像的高度和宽度。
sils = ipts[0]  # [n, s, h, w]
# 这一行代码检查 sils 的维度是否为 4,如果是,则使用 unsqueeze 方法在第二维上添加一个维度,
	# 将其变为五维张量。
if len(sils.size()) == 4:
	  sils = sils.unsqueeze(1)
# 删除变量ipts
del ipts

2、这四行读完之后,进入主要部分,全是需要重点理解的模块部分。

out = self.Backbone(sils)  # [n, c, s, h, w]
out = self.HPP(out)  # [n, c, s, p]
out = self.TFA(out, seqL)  # [n, c, p]
embs = self.Head(out)  # [n, c, p]

2.1、先理解第一个模块——Backbone模块

out = self.Backbone(sils)  # [n, c, s, h, w]
2.1.1、进入该方法,会找到这两句代码,主要是第一句,怎么理解呢?接着找到get_backbone方法
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
self.Backbone = SetBlockWrapper(self.Backbone)
1、model_cfg[‘backbone_cfg’]就是配置文件,可以理解为给改方法传参。
2、get_backbone方法,用于获取模型的主干网络
    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model."""
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            Backbone = nn.ModuleList([self.get_backbone(cfg)
                                      for cfg in backbone_cfg])
            return Backbone
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

这里会进入到BaseModel类中,这个类听名字就知道是所有模型类的基类,所有的模型将会继承自这个类。所有你暂时可以理解为,他就是把所有的模型可能都需要的方法给集成在一起了。

2.1进入这个方法,有两个if大山挡在眼前,兄弟们,该怎么办,还能怎么办,愚公移山

在两个if里面。又调用了两个函数is_dict,和is_list,继续深入了解,进入utils里面找到这两个方法

# 如果 x 是列表类型或者 ModuleList 类型中的一个,那么这个函数会返回 True,否则返回 False
def is_list(x):
    return isinstance(x, list) or isinstance(x, nn.ModuleList)
# 如果 x 是字典类型、有序字典类型或者自定义的有序字典类型中的一个,那么这个函数会返回 True,否则返回 False。
def is_dict(x):
    return isinstance(x, dict) or isinstance(x, OrderedDict) or isinstance(x, Odict)

那不就是判断传进来的参数model_cfg[‘backbone_cfg’],到底是字典类型还是列表类型。
那model_cfg[‘backbone_cfg’]到底是什么类型勒。不知丢,先放着。

OK,大致了解完这两个拦路虎之后,再看后面这两个大山
第一个

if is_dict(backbone_cfg):
  Backbone = get_attr_from([backbones], backbone_cfg['type'])
  valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
  return Backbone(**valid_args)

如果是字典类型,哦豁,又有两头豹子get_attr_from和get_valid_args拦着路的,接着干
发现,这两个也在utils里面。找到它,注意,我们的要求,现在只要求看懂,不要求写,所以只做高效的事,别傻傻的问自己为啥不会写。

# 从一个或多个源中获取指定属性的值
# 这个函数会从多个源中尝试获取指定属性的值,如果第一个源中不存在该属性,就继续从下一个源中查找,直到找到为止。
# 如果所有的源都没有该属性,则可能会引发 AttributeError 异常。
def get_attr_from(sources, name):
    try:
        return getattr(sources[0], name)
    except:
        return get_attr_from(sources[1:], name) if len(sources) > 1 else getattr(sources[0], name)

# 这个函数会根据给定对象的参数列表,从输入参数中筛选出有效的参数,并返回一个字典,其中只包含预期的参数键值对
def get_valid_args(obj, input_args, free_keys=[]):
    if inspect.isfunction(obj):
        expected_keys = inspect.getfullargspec(obj)[0]
    elif inspect.isclass(obj):
        expected_keys = inspect.getfullargspec(obj.__init__)[0]
    else:
        raise ValueError('Just support function and class object!')
    unexpect_keys = list()
    expected_args = {}
    for k, v in input_args.items():
        if k in expected_keys:
            expected_args[k] = v
        elif k in free_keys:
            pass
        else:
            unexpect_keys.append(k)
    if unexpect_keys != []:
        logging.info("Find Unexpected Args(%s) in the Configuration of - %s -" %
                     (', '.join(unexpect_keys), obj.__name__))
    return expected_args

根据上面的注释,我们大致可以了解了一下这两个函数的作用。回到整个部分

if is_dict(backbone_cfg):
  Backbone = get_attr_from([backbones], backbone_cfg['type'])
  valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
  return Backbone(**valid_args)

如果是字典类型,表示要获取单个主干模型。接下来,从 backbones 列表中获取与 ‘type’ 键指定类型相匹配的主干模型。可以查看配置文件里面主干模型的名字,是plain。这里的[backbones]刚好对应着backbones文件夹。
使用 get_attr_from 函数从 backbones 列表中获取匹配的主干模型,然后,使用 get_valid_args 函数从 backbone_cfg 中筛选出有效的参数,并传递给主干模型的构造函数,最终实例化主干模型并返回。

if is_list(backbone_cfg):
    Backbone = nn.ModuleList([self.get_backbone(cfg)
                              for cfg in backbone_cfg])

如果是列表类型,表示要获取多个主干模型。
使用列表推导式递归地调用 get_backbone 函数来获取每个子配置的主干模型,并将它们包装在 nn.ModuleList 中返回。

所以,self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])这句就理解为得到了主干模型
这是第一个模块,明天在接着分析下一个模块

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值