在上篇的blog中,写了一下对于ST-GCN论文的分析ST-GCN论文分析_Eric加油学!的博客-CSDN博客,这篇blog写一下对于ST-GCN源码的理解和整理,参考了一些写的比较好的文章,在文末附上链接。
文末有更新的ST-GCN复现全过程(详细)
目录
def get_hop_distance(num_node,edge,max_hop=1):
def get_adjacency(self,strategy):
整体的运行逻辑
静态代码流
主程序 st-gcn-master/main.py
#主程序 st-gcn-master/main.py
#调用不同类文件 (recognition类、demo类)
#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')
#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()])
#c、接受命令行中参数
arg = parser.parse_args()
#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor]
p = Processor(sys.argv[2:])
#e、调用recognition类中开始函数
p.start()
类程序1:recognition类 st-gcn-master/processor/recognition.py
def weights_init(m): #权重初始化
class REC_Processor(Processor):
def load_model(self): #加载模型
def load_optimizer(self): #加载优化器
def adjust_lr(self): #调整学习率
def show_topk(self,k): #显示精度
def train(self):
def test(self,evaluation = True):
def get_parser(add_help = False):
类程序2:processor类 st-gcn-master/processor/processor.py
class Processor(IO):
def __init__(self,argv=None):
def init_environment(self):
def load_optimizer(self):
def load_data(self):
def show_epoch_info(self):
def show_iter_info(self):
def train(self):
def test(self):
def start(self):
def get_parser(add_help = False):
类程序3:IO类 st-gcn-master/processor/io.py
class IO():
def __init__(self,argv=None):
def load_arg(self,argv=None):
def init_environment(self):
def load_model(self):
def load_weights(self):
def gpu(self):
def start(self):
def get_parser(add_help=False):
动态代码流
以NTU交叉主题模型训练为例:
当在终端输入命令: python main.py recognition -c config/st_gcn/ntu-xsub/train.yaml后,执行主程序
#a、导入recognition类
processors['recognition'] = import_class('processor.recognition.REC_Processor')
#b、调取recognition中默认参数
subparsers.add_parser(k, parents=[p.get_parser()])
# ---> def get_parser(add_help=False):
#c、接受命令行中参数
arg = parser.parse_args()
#d、实例化recognition类并传入命令行中参数(同时完成类初始化)
Processor = processors[arg.processor] #arg:processor:recognition
p = Processor(sys.argv[2:]) #sys.argv[2:]:-c config/st_gcn/ntu-xsub/train.yaml
其中实例化和初始化过程如下:
# processor/processor.py
class Processor(IO):
def __init__(self,argv=None):
self.load_argv(argv) #参数加载,得到self.arg
# --> def load_arg(self,argv=None):
'''
1、读取默认参数到参数表
2、使用输入参数更新参数表
3、读取参数配置文件更新参数表 配置文件:ntu-xsub/train.yaml
4、使用输入参数更新参数表
'''
self.init_environment():
super().init_environment() #继承调用 processor/io.py
'''
self.io= 获取自定义包中self.io类
self.io.save_arg(self.arg) 将参数表保存到工作区配置文件
如果使用GPU:获取GPU号和设备号
'''
#添加定义类参数
self.result = dict()
self.iter_info = dict()
self.epoch_info = dict()
self.meta_info = dict(epoch=0, iter=0)
self.load_model() # --> recognition.py
def load_model(self):
self.model = #下载模型,获得模型self.model
#模型文件: /net/st_gcn.py
self.model.apply(weights_init) #权重初始化,见def weights_init(m):
self.loss = nn.CrossEntropyLoss() #定义交叉商为损失函数
self.load_weights():
self.gpu():
'''
def gpu(self):
将self.arg、self.io等放到gpus上
如果使用gpu且数量大于1,使模型并行
'''
self.load_data():
#--> def load_data(self):
''