import torch
import torch.nn as nn
class AttentionModule(nn.Module):
def __init__(self, input_channels):
super(AttentionModule, self).__init__()
self.conv = nn.Conv2d(input_channels, 1, kernel_size=1) # 用于学习注意力权重的卷积层
self.sigmoid = nn.Sigmoid() # Sigmoid函数用于将注意力权重限制在0到1之间
def forward(self, x):
# 将输入x的尺寸从[batch_size, input_channels, height, width]变为[batch_size, input_channels, height, width]
# x = x.unsqueeze(2) # 在第3维上增加一个维度
attention_weights = self.conv(x) # 应用卷积层获取注意力权重,尺寸为[batch_size, 1, height, width]
attention_weights = self.sigmoid(attention_weights) # 使用Sigmoid函数将注意力权重限制在0到1之间
attended_features = x * attention_weights # 将注意力权重应用到输入特征上
return attended_features # 去除增加的维度
# 创建一个尺寸为[16, 32, 1, 5]的随机输入张量
input_tensor = torch.randn(16, 32, 8, 3)
# 创建注意力模块实例
attention_module = AttentionModule(input_channels=32)
# 应用注意力模块到输入张量上
output_tensor = attention_module(input_tensor)
# 打印输出张量的尺寸
print(output_tensor.size())
python自注意力模块
最新推荐文章于 2024-03-20 12:36:55 发布