ST-GCN源码分析

本文主要分析ST-GCN的源码实现,包括静态和动态代码流的运行逻辑,重点解析`graph.py`中的`get_edge`、`get_hop_distance`、`get_adjacency`函数以及`st-gcn.py`中的GCN和TCN模块。ST-GCN网络输入为(N,C,T,V,M)张量,通过GCN和TCN进行空间和时间域的特征提取。邻接矩阵的构建和节点分组策略在`graph.py`中实现,而网络结构和参数更新集中在`st-gcn.py`。" 111292980,10295056,Python Treeview控件与BeautifulSoup select方法解析,"['Python库', '前端开发', 'HTML解析']
摘要由CSDN通过智能技术生成

在上篇的blog中,写了一下对于ST-GCN论文的分析ST-GCN论文分析_Eric加油学!的博客-CSDN博客,这篇blog写一下对于ST-GCN源码的理解和整理,参考了一些写的比较好的文章,在文末附上链接。

文末有更新的ST-GCN复现全过程(详细)

目录

整体的运行逻辑

静态代码流

动态代码流

核心代码分析

graph.py

def get_edge(self,layout): 

def get_hop_distance(num_node,edge,max_hop=1):

def get_adjacency(self,strategy):

st-gcn.py

 GCN模块

TCN模块


整体的运行逻辑

静态代码流

主程序  st-gcn-master/main.py

#主程序  st-gcn-master/main.py
#调用不同类文件 (recognition类、demo类)

#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')

#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()]) 

#c、接受命令行中参数
arg = parser.parse_args()

#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor]
p = Processor(sys.argv[2:]) 

#e、调用recognition类中开始函数
p.start()

类程序1:recognition类        st-gcn-master/processor/recognition.py

def weights_init(m):    #权重初始化


class REC_Processor(Processor):
    def load_model(self):        #加载模型

    def load_optimizer(self):    #加载优化器

    def adjust_lr(self):         #调整学习率
    
    def show_topk(self,k):        #显示精度

    def train(self):

    def test(self,evaluation = True):

    def get_parser(add_help = False):

类程序2:processor类        st-gcn-master/processor/processor.py

class Processor(IO):
    def __init__(self,argv=None):
    
    def init_environment(self):

    def load_optimizer(self):

    def load_data(self):

    def show_epoch_info(self):

    def show_iter_info(self):

    def train(self):

    def test(self):

    def start(self):

    def get_parser(add_help = False):

类程序3:IO类         st-gcn-master/processor/io.py

class IO():
    def __init__(self,argv=None):

    def load_arg(self,argv=None):

    def init_environment(self):

    def load_model(self):

    def load_weights(self):

    def gpu(self):

    def start(self):

    def get_parser(add_help=False):

动态代码流

以NTU交叉主题模型训练为例:

当在终端输入命令: python main.py recognition -c config/st_gcn/ntu-xsub/train.yaml后,执行主程序

#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')
#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()])     
#    --->    def get_parser(add_help=False):
#c、接受命令行中参数
arg = parser.parse_args()

#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor]    #arg:processor:recognition
p = Processor(sys.argv[2:])   #sys.argv[2:]:-c config/st_gcn/ntu-xsub/train.yaml

其中实例化和初始化过程如下: 

# processor/processor.py
class Processor(IO):
    def __init__(self,argv=None):
        self.load_argv(argv)    #参数加载,得到self.arg
            # --> def load_arg(self,argv=None):
            '''
            1、读取默认参数到参数表
            2、使用输入参数更新参数表
            3、读取参数配置文件更新参数表 配置文件:ntu-xsub/train.yaml
            4、使用输入参数更新参数表
            '''
   
        self.init_environment():
            super().init_environment() #继承调用 processor/io.py
            '''
                self.io=        获取自定义包中self.io类
                self.io.save_arg(self.arg)    将参数表保存到工作区配置文件
                如果使用GPU:获取GPU号和设备号
            '''
            #添加定义类参数
            self.result = dict()
            self.iter_info = dict()
            self.epoch_info = dict()
            self.meta_info = dict(epoch=0, iter=0)

        self.load_model()    # --> recognition.py
            def load_model(self):
                self.model =     #下载模型,获得模型self.model
                #模型文件: /net/st_gcn.py
                self.model.apply(weights_init)  #权重初始化,见def weights_init(m):
                self.loss = nn.CrossEntropyLoss()   #定义交叉商为损失函数

        self.load_weights():
        
        self.gpu():
        '''
            def gpu(self):  
                将self.arg、self.io等放到gpus上
                如果使用gpu且数量大于1,使模型并行
        '''

        self.load_data():
            #--> def load_data(self):
            ''
ST-GCN(Spatial Temporal Graph Convolutional Networks)是一种用于人体动作识别的深度学习模型,其源码解析可以分为以下几个方面。 首先,ST-GCN是基于图卷积神经网络GCN)的一种扩展模型,在处理视频序列时,将每一帧的姿势数据(通常使用OpenPose进行姿势估计)建模为图结构,其中节点对应关键点,边表示节点之间的空间关系。源码中主要包含了构建图结构的代码,包括节点的定义、边的连接方式以及图结构的转换。 其次,ST-GCN引入了时序关系建模,以利用动作序列的时间信息。源码中涉及到的关键部分是时序卷积层的实现,对于每一个节点,通过聚合邻居节点的特征信息来更新当前节点的特征表示,从而实现对时序关系的建模。此外,还包括了一些预处理方法,如时间差分和层间残差等,用于增强模型的表达能力。 再次,ST-GCN还包含了一些辅助模块,用于提取更丰富的时空特征。例如,源码中提供了一个ST-GCN的变种模型,引入了多尺度特征融合的机制,通过将不同尺度的特征进行融合,提高了模型的鲁棒性和泛化能力。 最后,源码中还包括了一些训练和测试的相关代码,用于对ST-GCN模型进行训练和评估。这部分代码主要包括了数据加载、模型的构建、损失函数的定义以及优化器的选择等。 总之,ST-GCN源码解析涉及了构建图结构、时序关系建模、辅助模块和训练测试等方面,通过对这些代码的解析,可以深入理解ST-GCN模型的原理和实现细节。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值