gaitpart代码解析
opengait源码
1、从forward函数开始解析,将会遇到下面四行代码,进而解析如下
ipts, labs, _, _, seqL = inputs
sils = ipts[0] # [n, s, h, w]
if len(sils.size()) == 4:
sils = sils.unsqueeze(1)
del ipts
这四行代码是固定不变的,各部分作用如下注释,可以暂时视为不变量,了解即可,先不要作为重心。
# 将输入数据、标签和序列长度从 inputs 中解包出来。
ipts, labs, _, _, seqL = inputs
# 这一行代码将输入数据中的第一个张量赋值给了变量 sils。
# sils 是一个四维张量,其形状为 [n, s, h, w],其中 n 表示样本数量,s 表示序列长度,
# h 和 w 分别表示图像的高度和宽度。
sils = ipts[0] # [n, s, h, w]
# 这一行代码检查 sils 的维度是否为 4,如果是,则使用 unsqueeze 方法在第二维上添加一个维度,
# 将其变为五维张量。
if len(sils.size()) == 4:
sils = sils.unsqueeze(1)
# 删除变量ipts
del ipts
2、这四行读完之后,进入主要部分,全是需要重点理解的模块部分。
out = self.Backbone(sils) # [n, c, s, h, w]
out = self.HPP(out) # [n, c, s, p]
out = self.TFA(out, seqL) # [n, c, p]
embs = self.Head(out) # [n, c, p]
2.1、先理解第一个模块——Backbone模块
out = self.Backbone(sils) # [n, c, s, h, w]
2.1.1、进入该方法,会找到这两句代码,主要是第一句,怎么理解呢?接着找到get_backbone方法
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
self.Backbone = SetBlockWrapper(self.Backbone)
1、model_cfg[‘backbone_cfg’]就是配置文件,可以理解为给改方法传参。
2、get_backbone方法,用于获取模型的主干网络
def get_backbone(self, backbone_cfg):
"""Get the backbone of the model."""
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
return Backbone(**valid_args)
if is_list(backbone_cfg):
Backbone = nn.ModuleList([self.get_backbone(cfg)
for cfg in backbone_cfg])
return Backbone
raise ValueError(
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
这里会进入到BaseModel类中,这个类听名字就知道是所有模型类的基类,所有的模型将会继承自这个类。所有你暂时可以理解为,他就是把所有的模型可能都需要的方法给集成在一起了。
2.1进入这个方法,有两个if大山挡在眼前,兄弟们,该怎么办,还能怎么办,愚公移山
在两个if里面。又调用了两个函数is_dict,和is_list,继续深入了解,进入utils里面找到这两个方法
# 如果 x 是列表类型或者 ModuleList 类型中的一个,那么这个函数会返回 True,否则返回 False
def is_list(x):
return isinstance(x, list) or isinstance(x, nn.ModuleList)
# 如果 x 是字典类型、有序字典类型或者自定义的有序字典类型中的一个,那么这个函数会返回 True,否则返回 False。
def is_dict(x):
return isinstance(x, dict) or isinstance(x, OrderedDict) or isinstance(x, Odict)
那不就是判断传进来的参数model_cfg[‘backbone_cfg’],到底是字典类型还是列表类型。
那model_cfg[‘backbone_cfg’]到底是什么类型勒。不知丢,先放着。
OK,大致了解完这两个拦路虎之后,再看后面这两个大山
第一个
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
return Backbone(**valid_args)
如果是字典类型,哦豁,又有两头豹子get_attr_from和get_valid_args拦着路的,接着干
发现,这两个也在utils里面。找到它,注意,我们的要求,现在只要求看懂,不要求写,所以只做高效的事,别傻傻的问自己为啥不会写。
# 从一个或多个源中获取指定属性的值
# 这个函数会从多个源中尝试获取指定属性的值,如果第一个源中不存在该属性,就继续从下一个源中查找,直到找到为止。
# 如果所有的源都没有该属性,则可能会引发 AttributeError 异常。
def get_attr_from(sources, name):
try:
return getattr(sources[0], name)
except:
return get_attr_from(sources[1:], name) if len(sources) > 1 else getattr(sources[0], name)
# 这个函数会根据给定对象的参数列表,从输入参数中筛选出有效的参数,并返回一个字典,其中只包含预期的参数键值对
def get_valid_args(obj, input_args, free_keys=[]):
if inspect.isfunction(obj):
expected_keys = inspect.getfullargspec(obj)[0]
elif inspect.isclass(obj):
expected_keys = inspect.getfullargspec(obj.__init__)[0]
else:
raise ValueError('Just support function and class object!')
unexpect_keys = list()
expected_args = {}
for k, v in input_args.items():
if k in expected_keys:
expected_args[k] = v
elif k in free_keys:
pass
else:
unexpect_keys.append(k)
if unexpect_keys != []:
logging.info("Find Unexpected Args(%s) in the Configuration of - %s -" %
(', '.join(unexpect_keys), obj.__name__))
return expected_args
根据上面的注释,我们大致可以了解了一下这两个函数的作用。回到整个部分
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
return Backbone(**valid_args)
如果是字典类型,表示要获取单个主干模型。接下来,从 backbones 列表中获取与 ‘type’ 键指定类型相匹配的主干模型。可以查看配置文件里面主干模型的名字,是plain。这里的[backbones]刚好对应着backbones文件夹。
使用 get_attr_from 函数从 backbones 列表中获取匹配的主干模型,然后,使用 get_valid_args 函数从 backbone_cfg 中筛选出有效的参数,并传递给主干模型的构造函数,最终实例化主干模型并返回。
if is_list(backbone_cfg):
Backbone = nn.ModuleList([self.get_backbone(cfg)
for cfg in backbone_cfg])
如果是列表类型,表示要获取多个主干模型。
使用列表推导式递归地调用 get_backbone 函数来获取每个子配置的主干模型,并将它们包装在 nn.ModuleList 中返回。
所以,self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
这句就理解为得到了主干模型
这是第一个模块,明天在接着分析下一个模块