KGTN论文+代码运行

在这里插入图片描述

运行main_KGTN.py:

主要运行流程:

  1. 定义一个adjacent_matrix
    ndarray [1000,1000] # 针对车辆重识别VERI776应该是[576,576]
  2. 定义一个KGTN模型
import numpy as np
import math
import torch
import torch.nn.functional as F
import torch.nn as nn

class KGTN(nn.Module):
    def __init__(self, 
                 feature_dim, 
                 num_classes,
                 use_all_base,
                 use_knowledge_propagation,
                 ggnn_time_step=None,
                 pretrain=False,
                 pretrain_model=None,
                 graph_learnable=False,
                 classifier_type='inner_product',
                 adjacent_matrix=None):
        super(KGTN, self).__init__()

        self.feature_dim = feature_dim
        self.use_knowledge_propagation = use_knowledge_propagation
        self.use_all_base = use_all_base
        self.ggnn_time_step = ggnn_time_step

        self.last_fc_weight = nn.Parameter(torch.rand(feature_dim, num_classes))

        if use_knowledge_propagation:
            self.ggnn = KGTM(
                num_nodes = num_classes, 
                use_all_base = use_all_base,
                hidden_state_channel = feature_dim,
                output_channel = feature_dim,
                time_step = self.ggnn_time_step,
                adjacent_matrix = adjacent_matrix,
                graph_learnable=graph_learnable 
            )
        # initialize parameters and load pretrain
        self.param_init()
        self.load_pretrain(pretrain_model, pretrain)
        
        assert classifier_type in ['inner_product', 'cosine', 'pearson']
        self.classifier_type=classifier_type
        if self.classifier_type == 'cosine' or self.classifier_type == 'pearson':
            init_scale_cls = 10
            self.scale_cls = nn.Parameter(
                torch.FloatTensor(1).fill_(init_scale_cls),
                requires_grad=True)

    def forward(self, input):
        if self.use_knowledge_propagation:
            step_fc_weight = self.ggnn(self.last_fc_weight.transpose(0, 1).unsqueeze(0))
            weight = step_fc_weight[-1]
            weight = weight.squeeze().transpose(0, 1)
            if self.classifier_type == 'cosine':
                # cos sim
                input = F.normalize(input, p=2, dim=1, eps=1e-12)
                weight = F.normalize(weight, p=2, dim=0, eps
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值