[Network Module] Computing the Similarity of Two Attributed Graphs with a GCN

Background

  • In remote sensing imagery, the features of a geographic object are closely tied to those of its adjacent objects, so we want to make full use of the spatial relationships between objects.
  • The goal is to embed these spatial relationships into a remote sensing object recognition network by computing the similarity between the scene graph formed by an object and its neighbors and the corresponding label graph.
  • There are many ways to compute graph similarity; since graph kernels proved hard to adapt to complex geographic environments in our experiments, we use a GCN instead.

So, in the scene shown below, we take a ship or the coastline as the central object, build a graph over it and its adjacent objects, extract per-object features with the backbone and object masks, and then compute the similarity between the two attributed graphs.

[Figure: scene graph built from a central object (e.g., a ship or the coastline) and its adjacent objects]

Approach

  1. Use a GCN to transform each graph's node features (given its adjacency matrix) into a common embedding dimension; the propagation rule is sketched after this list.
  2. Concatenate the two graphs' features along the channel dimension and pass them through a linear layer to obtain a 1-D similarity value.
  3. Each training sample contains n geographic objects; each object together with its neighbors forms one pair of graphs (one from the remote sensing scene graph, one from the label graph), giving n graph pairs that are processed in a loop.
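
For reference, the standard GCN layer (Kipf & Welling) propagates node features with the symmetrically normalized rule below; the GraphConvolution layer in Step 2 uses a simplified variant that divides each node's features by its row degree and applies the activation before aggregating over the adjacency matrix:

H^{(l+1)} = \sigma\left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right), \qquad \tilde{A} = A + I, \quad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}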

Code

Step 1: Import the required libraries

import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the alternative similarity readouts below

Step 2: The GCN layer

class GraphConvolution(nn.Module):
    """
    A simple graph convolution layer.
    inputs:  node_features (n, in_features), adj (n, n)
    outputs: node_features (n, out_features)
    """
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x, adj):
        # x:   node features, shape (num_nodes, in_features)
        # adj: adjacency matrix, shape (num_nodes, num_nodes)
        x = self.linear(x)

        # Normalize each node's features by its degree
        # (clamp avoids division by zero for isolated nodes)
        degree = torch.sum(adj, dim=1, keepdim=True).clamp(min=1)
        x = x / degree

        # Apply the activation, then aggregate neighbor features
        x = self.relu(x)
        x = torch.matmul(adj, x)

        return x
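
The symmetric normalization from the formula above can be precomputed once instead of the per-layer degree division; a minimal sketch, assuming the same dense (n, n) float adjacency tensors used in this post (not part of the module above):

def normalize_adj(adj):
    # A_hat = D^{-1/2} (A + I) D^{-1/2}: add self-loops, then symmetric degree normalization
    adj = adj + torch.eye(adj.shape[0], device=adj.device)
    deg_inv_sqrt = adj.sum(dim=1).clamp(min=1).pow(-0.5)   # D^{-1/2} stored as a vector
    return deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)

Passing normalize_adj(adj) into the layers (and dropping the per-layer degree division) would bring GraphConvolution closer to the standard formulation.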

Step 3: Compute the similarity of the two graphs

class GraphSimilarity(nn.Module):
    """
    Compute the similarity of two graphs that share the same adjacency matrix.
    inputs:
        node_features_left:  (n, d1)
        node_features_right: (n, d2)
        adj: (n, n)
    outputs:
        similarity: scalar tensor
    """
    def __init__(self, num_features_left, num_features_right, hid_features):
        super(GraphSimilarity, self).__init__()
        self.gcn1_left = GraphConvolution(num_features_left, hid_features)
        self.gcn2_left = GraphConvolution(hid_features, hid_features)

        self.gcn1_right = GraphConvolution(num_features_right, hid_features)
        self.gcn2_right = GraphConvolution(hid_features, hid_features)

        self.linear = nn.Linear(hid_features * 2, 1)

    def forward(self, node_features_left, node_features_right, adj):
        left = self.gcn1_left(node_features_left, adj)
        left = self.gcn2_left(left, adj)

        right = self.gcn1_right(node_features_right, adj)
        right = self.gcn2_right(right, adj)

        # Alternative readouts:
        # similarity = F.cosine_similarity(left, right, dim=-1).mean()  # cosine similarity
        # similarity = torch.sum(left * right) / (torch.norm(left) * torch.norm(right))  # normalized dot product

        # Concatenate the two graphs' node embeddings, map each node pair to a scalar,
        # and average over nodes to get a single graph-level similarity
        similarity = self.linear(torch.cat((left, right), dim=-1)).mean()
        return similarity
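
A minimal standalone usage of the module on one pair of toy graphs that share an adjacency matrix (the sizes below are arbitrary and only for illustration):

# Hypothetical toy example: one pair of 5-node graphs with a shared adjacency matrix
g = GraphSimilarity(num_features_left=64, num_features_right=1, hid_features=128)
feat_scene = torch.randn(5, 64)        # node features from the backbone (scene graph)
feat_label = torch.randn(5, 1)         # node features from the label graph
adj = torch.ones(5, 5)                 # fully connected toy adjacency
print(g(feat_scene, feat_label, adj))  # a single scalar similarity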

Step 4: Iterate over the adjacency graphs of a single sample's objects

def calculate_similarity(model, feature1, feature2, adj):
    """
    Inputs: node features 1, node features 2 (node labels), and the adjacency matrix.
    :param feature1: tensor (n, d1); n is the number of objects in one sample, not the number of neighbors (network outputs)
    :param feature2: tensor (n, d2); n is the number of objects in one sample, not the number of neighbors (object labels)
    :param adj: tensor (n, n)
    :return: tensor (n,)

    (1) For each object, collect its adjacent objects to form a graph: n objects give n graphs.
    (2) Feed the current scene graph and the label graph into the network to compute their similarity.
    """
    sim_matrix = torch.zeros(size=(feature1.shape[0],))
    # Iterate over each object
    for i in range(feature1.shape[0]):
        indices = torch.where(adj[i] == 1)[0]  # indices of the adjacent objects
        # Assumes every object has at least one neighbor; see the note below for isolated objects

        fea1 = feature1[indices, :]
        fea2 = feature2[indices, :]
        adj_ = adj[indices, :][:, indices].float()  # adjacency submatrix of the current neighborhood

        sim = model(fea1, fea2, adj_)
        sim_matrix[i] = sim

    return sim_matrix
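
Note that if an object has no neighbors (its adjacency row is all zeros), the loop above feeds empty tensors into the model and the .mean() readout returns NaN. A minimal variant that guards against this, assuming a default similarity of 0 is acceptable for isolated objects (an assumption, not part of the original code):

def calculate_similarity_safe(model, feature1, feature2, adj):
    # Same loop as above, but skips isolated objects instead of producing NaN
    sim_matrix = torch.zeros(size=(feature1.shape[0],))
    for i in range(feature1.shape[0]):
        indices = torch.where(adj[i] == 1)[0]
        if indices.numel() == 0:   # no neighbors: leave the similarity at 0
            continue
        fea1 = feature1[indices, :]
        fea2 = feature2[indices, :]
        adj_ = adj[indices, :][:, indices].float()
        sim_matrix[i] = model(fea1, fea2, adj_)
    return sim_matrix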

Step 5: Test

if __name__ == '__main__':
    n = 20                # number of objects
    d1 = 64               # node feature dimension of graph 1
    d2 = 1                # node feature dimension of graph 2
    embedding_size = 128  # common embedding dimension
    outputs = torch.randn(size=(n, d1))
    labels = torch.randn(size=(n, d2))
    adj = torch.randint(0, 2, size=(n, n))
    model = GraphSimilarity(d1, d2, embedding_size)

    sim = calculate_similarity(model, outputs, labels, adj)
    print("Graph Similarity:", sim)
    print(sim.shape)

For n objects we obtain a vector of length n, where each value is the similarity between graph 1 (e.g., the remote sensing scene graph) and graph 2 (e.g., the object label graph) formed by that node and its neighbors.
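
As an illustration of how this vector might be used during training, continuing from the Step 5 variables (a sketch under the assumption that matching pairs should score close to 1; the loss is not from the original post):

# Hypothetical training-style usage: push the per-object similarities towards 1
target = torch.ones_like(sim)
loss = F.mse_loss(sim, target)
loss.backward()   # gradients flow back through the GCN layers and the linear head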

Model structure

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1               [1, 20, 128]           8,320
              ReLU-2               [1, 20, 128]               0
  GraphConvolution-3               [1, 20, 128]               0
            Linear-4               [1, 20, 128]          16,512
              ReLU-5               [1, 20, 128]               0
  GraphConvolution-6               [1, 20, 128]               0
            Linear-7               [1, 20, 128]             256
              ReLU-8               [1, 20, 128]               0
  GraphConvolution-9               [1, 20, 128]               0
           Linear-10               [1, 20, 128]          16,512
             ReLU-11               [1, 20, 128]               0
 GraphConvolution-12               [1, 20, 128]               0
           Linear-13                 [1, 20, 1]             257
================================================================
Total params: 41,857
Trainable params: 41,857
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 39.06
Forward/backward pass size (MB): 0.23
Params size (MB): 0.16
Estimated Total Size (MB): 39.46
----------------------------------------------------------------

Summary

  • This post records a simple module that uses a GCN to compute the similarity between a scene graph and its corresponding label graph.
  • Because the number of geographic objects differs from sample to sample (and so does the number of objects that satisfy the condition), the batch dimension is dropped at the model input: each sample is fed into the network individually to compute its similarities (see the sketch below).
  • There are many more sophisticated architectures for this kind of task, which remain to be explored.
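
A minimal sketch of that per-sample loop, assuming each sample is stored as a (features, labels, adj) tuple in a plain Python list (the names and sizes here are hypothetical):

# Hypothetical batch of variable-size samples: (features (n_i, d1), labels (n_i, d2), adj (n_i, n_i))
batch = [
    (torch.randn(12, 64), torch.randn(12, 1), torch.randint(0, 2, (12, 12))),
    (torch.randn(7, 64),  torch.randn(7, 1),  torch.randint(0, 2, (7, 7))),
]
model = GraphSimilarity(64, 1, 128)
sims = [calculate_similarity(model, f, l, a) for f, l, a in batch]  # one (n_i,) vector per sample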