Use a graph convolutional network (GCN) to learn the temporal correlations across the entire video.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MEGAN(nn.Module):
    def __init__(self, in_channels, out_channels, num_frames):
        super(MEGAN, self).__init__()

        self.num_frames = num_frames

        # Graph construction
        self.spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.temporal_conv = nn.Conv2d(1, 1, kernel_size=(num_frames, 1), stride=1, padding=0, bias=False)

        # Graph convolution layers (a minimal GCN sketch is given after this snippet)
        self.gcn_layers = nn.ModuleList()
        for i in range(4):
            self.gcn_layers.append(GCN(out_channels, out_channels))

        # Memory enhancement module
        self.mem_conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.mem_conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.mem_conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)

        # Output layers
        self.out_conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.out_conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.out_conv3 = nn.Conv2d(out_channels, in_channels * num_frames, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        # Graph construction
        b, t, c, h, w = x.size()
        x = x.view(-1, c, h, w)
        spatial_x = self.spatial_conv(x)
        spatial_x = spatial_x.view(b, t, -1, spatial_x.size(2), spatial_x.size(3))
        # Build a temporal summary: average over channels, then fuse the frames
        # with the temporal convolution
        temporal_x = spatial_x.transpose(1, 2).contiguous().view(b, -1, self.num_frames, spatial_x.size(3), spatial_x.size(4))
        temporal_x = temporal_x.mean(1, keepdim=True)
        # Conv2d expects a 4D input, so fold the spatial dimensions first
        temporal_x = temporal_x.view(b, 1, self.num_frames, -1)
        temporal_x = self.temporal_conv(temporal_x)
        temporal_x = temporal_x.view(b, 1, 1, spatial_x.size(3), spatial_x.size(4))
        temporal_x = temporal_x.repeat(1, spatial_x.size(2), 1, 1, 1)
        # Flatten the frame dimension back into the batch for the 2D convolutions
        spatial_x = spatial_x.view(-1, spatial_x.size(2), spatial_x.size(3), spatial_x.size(4))

        # Graph convolution layers
        for gcn_layer in self.gcn_layers:
            spatial_x = gcn_layer(spatial_x, temporal_x)

        # Memory enhancement module
        mem_x = self.mem_conv1(spatial_x)
        mem_x = F.relu(mem_x)
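        # --- Hedged completion: the snippet above stops here, so the rest of
        # --- forward() is only a sketch based on the layers declared in
        # --- __init__, not the author's original code.
        mem_x = self.mem_conv2(mem_x)
        mem_x = F.relu(mem_x)
        mem_x = self.mem_conv3(mem_x)
        # Memory-enhanced features via a residual connection
        out = spatial_x + mem_x

        # Output layers
        out = F.relu(self.out_conv1(out))
        out = F.relu(self.out_conv2(out))
        # Aggregate the per-frame features, then expand back to all frames
        out = out.view(b, t, -1, out.size(2), out.size(3)).mean(1)
        out = self.out_conv3(out)
        out = out.view(b, t, c, out.size(2), out.size(3))
        return out

The snippet also calls a GCN layer that is never defined. As a stand-in so the model runs end to end, the sketch below simply broadcasts the temporal summary to every frame and fuses it with the per-frame features through a 1x1 convolution. Only the class name GCN and its two-argument forward(spatial_x, temporal_x) signature come from the usage above; everything else is an assumption rather than a real graph convolution.

class GCN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, spatial_x, temporal_x):
        # spatial_x: (b*t, C, h, w); temporal_x: (b, C, 1, h, w)
        b, c, _, h, w = temporal_x.size()
        t = spatial_x.size(0) // b
        # Broadcast the temporal summary to every frame and fuse additively
        temporal = temporal_x.expand(b, c, t, h, w).permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        return F.relu(self.proj(spatial_x + temporal))

A quick smoke test with made-up sizes (2 clips, 4 frames, RGB, 32x32) should return a tensor with the same shape as the input:

model = MEGAN(in_channels=3, out_channels=16, num_frames=4)
video = torch.randn(2, 4, 3, 32, 32)
print(model(video).shape)  # torch.Size([2, 4, 3, 32, 32])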
       
