import torch
import torch.nn as nn
class DFF(nn.Module):
    """Dynamic feature fusion of two volumetric feature maps.

    The two inputs are concatenated along the channel axis, reweighted by a
    global channel attention (global average pool -> 1x1x1 conv -> sigmoid),
    reduced back to ``dim`` channels with a 1x1x1 convolution, and finally
    modulated by a single-channel spatial attention map computed from both
    inputs (sigmoid of the sum of two 1x1x1 projections).

    Args:
        dim: Number of channels of each input, and of the output.
    """

    def __init__(self, dim):
        super().__init__()
        # Global average pool over all three spatial axes (D, H, W).
        # A 2-D adaptive pool on a 5-D tensor would average only the last two
        # axes, giving a per-depth-slice descriptor instead of a global one.
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        # Channel attention over the concatenated 2*dim channels.
        self.conv_atten = nn.Sequential(
            nn.Conv3d(dim * 2, dim * 2, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        # Reduce the concatenated 2*dim channels back to dim.
        self.conv_redu = nn.Conv3d(dim * 2, dim, kernel_size=1, bias=False)
        # Per-input projections to a single channel for the spatial attention.
        self.conv1 = nn.Conv3d(dim, 1, kernel_size=1, bias=True)
        self.conv2 = nn.Conv3d(dim, 1, kernel_size=1, bias=True)
        self.nonlin = nn.Sigmoid()

    def forward(self, x, skip):
        """Fuse ``x`` with ``skip``.

        Args:
            x: Tensor of shape (N, dim, D, H, W).
            skip: Tensor with the same batch and spatial sizes as ``x`` and
                ``dim`` channels (required by the channel concatenation).

        Returns:
            Tensor of shape (N, dim, D, H, W).
        """
        fused = torch.cat([x, skip], dim=1)
        # (N, 2*dim, 1, 1, 1) weights broadcast over every voxel.
        channel_att = self.conv_atten(self.avg_pool(fused))
        fused = self.conv_redu(fused * channel_att)
        # (N, 1, D, H, W) weights broadcast over every channel.
        spatial_att = self.nonlin(self.conv1(x) + self.conv2(skip))
        return fused * spatial_att
if __name__ == '__main__':
    # Smoke test: fuse two random (N, C, D, H, W) volumes and check that the
    # output keeps the input shape.
    x = torch.randn(1, 48, 128, 128, 128)
    skip = torch.randn(1, 48, 128, 128, 128)
    model = DFF(48)
    # Forward only: no_grad stops autograd from keeping the large
    # intermediate activations (hundreds of MB each at this size) alive.
    with torch.no_grad():
        output = model(x, skip)
    print(x.shape)
    print(output.shape)
    # Both lines print: torch.Size([1, 48, 128, 128, 128])
![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/eec9440caab94e42a327af300b9b389c.png)
DFF 代码介绍与创新点
最新推荐文章于 2024-06-14 10:30:15 发布