import torch
import torch.nn as nn
import torch.nn.functional as F
class MEGAN(nn.Module):
def __init__(self, in_channels, out_channels, num_frames):
    """Memory-Enhanced Graph Attention Network (MEGAN) module.

    Args:
        in_channels: channel count of each input frame.
        out_channels: internal feature-channel width used by all
            intermediate convolutions and GCN layers.
        num_frames: temporal length (frames per clip); the temporal
            convolution kernel spans exactly this many rows.
    """
    super(MEGAN, self).__init__()
    self.num_frames = num_frames
    # Store the feature width on the instance: forward() needs it to
    # reshape features, and the original code referenced a bare
    # `out_channels` there, which is a NameError at call time.
    self.out_channels = out_channels
    # Graph construction
    self.spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
    # Single-channel conv whose kernel covers all num_frames rows at once,
    # collapsing the temporal axis (no bias, no padding).
    self.temporal_conv = nn.Conv2d(1, 1, kernel_size=(num_frames, 1), stride=1, padding=0, bias=False)
    # Graph convolution layers (GCN is a project-local module defined elsewhere).
    self.gcn_layers = nn.ModuleList(GCN(out_channels, out_channels) for _ in range(4))
    # Memory enhancement module
    self.mem_conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
    self.mem_conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
    self.mem_conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
    # Output layers: final conv expands back to in_channels * num_frames
    # so every frame of the clip can be reconstructed from one feature map.
    self.out_conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
    self.out_conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
    self.out_conv3 = nn.Conv2d(out_channels, in_channels * num_frames, kernel_size=3, stride=1, padding=1, bias=True)
def forward(self, x):
    """Forward pass over a video clip.

    Args:
        x: input tensor, assumed shape (batch, time, channels, height,
           width) — inferred from the 5-way unpack below; TODO confirm
           against callers.

    NOTE(review): this method is truncated in the visible source — it
    stops partway through the memory-enhancement stage and never uses
    mem_conv2/mem_conv3 or the out_conv* layers, nor returns a value.
    """
    # Graph construction
    b, t, c, h, w = x.size()
    # Fold the time axis into the batch so the 2-D conv runs per frame.
    x = x.view(-1, c, h, w)
    spatial_x = self.spatial_conv(x)
    # Restore the (batch, time, feat, H, W) layout after the per-frame conv.
    spatial_x = spatial_x.view(b, t, -1, spatial_x.size(2), spatial_x.size(3))
    temporal_x = spatial_x.transpose(1, 2).contiguous().view(b, -1, self.num_frames, spatial_x.size(3), spatial_x.size(4))
    # Average over the feature axis, keeping a singleton dim for the conv.
    temporal_x = temporal_x.mean(1, keepdim=True)
    # NOTE(review): temporal_x is 5-D here, but temporal_conv is an
    # nn.Conv2d, which expects 4-D (N, C, H, W) input — likely a latent
    # shape bug; confirm intended semantics.
    temporal_x = self.temporal_conv(temporal_x)
    # Broadcast the temporal summary across the feature dimension.
    temporal_x = temporal_x.repeat(1, spatial_x.size(2), 1, 1, 1)
    # NOTE(review): `out_channels` is not defined in this scope (it was an
    # __init__ parameter) — this line raises NameError as written; it
    # presumably should use the stored feature width.
    spatial_x = spatial_x.view(-1, out_channels, spatial_x.size(3), spatial_x.size(4))
    # Graph convolution layers: each GCN mixes spatial features with the
    # temporal summary.
    for gcn_layer in self.gcn_layers:
        spatial_x = gcn_layer(spatial_x, temporal_x)
    # Memory enhancement module (truncated here in the visible source).
    mem_x = self.mem_conv1(spatial_x)
    mem_x = F.relu(mem_x)
# Uses a graph convolutional network (GCN) to learn temporal correlations
# across the entire video.
# First published 2023-02-27 10:37:37.