class Process_conf(nn.Module):
    """Fuse template features with search-region features and predict one
    confidence score per sample.

    forward inputs (per the original author's shape annotations):
        f_:      (B, 512, 25, 25) template / fused feature map.
        search_: (B, 256, 31, 31) search-region feature map.
    forward output:
        (B,) tensor of confidences in (0, 1).

    Args:
        in_channels: width of the intermediate RepBlock stack in
            ``down_conf``.
    """

    def __init__(self, in_channels):
        super().__init__()
        # 31x31 -> 25x25: the un-padded 7x7 conv shrinks the search map so it
        # can be concatenated channel-wise with f_.
        self.down_search = nn.Sequential(
            RepBlock(256, 256, kernel_size=3, padding=1),
            nn.Conv2d(256, 256, kernel_size=(7, 7)),
            RepBlock(256, 256, kernel_size=3, padding=1))
        # 768 = 512 (f_) + 256 (down-sampled search); reduce to 4 channels.
        self.down_conf = nn.Sequential(
            RepBlock(768, in_channels, kernel_size=3, padding=1),
            RepBlock(in_channels, in_channels, kernel_size=3, padding=1),
            RepBlock(in_channels, in_channels, kernel_size=3, padding=1),
            RepBlock(in_channels, 4, kernel_size=3, padding=1))
        # 4 * 25 * 25 = 2500 flattened features -> scalar logit.
        self.fc_blocks = nn.Sequential(
            nn.Linear(in_features=2500, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=1))
        # Initialization: small-std normal on every conv / linear weight.
        for block in (self.down_conf, self.fc_blocks, self.down_search):
            for layer in block.modules():
                if isinstance(layer, (nn.Conv2d, nn.Linear)):
                    torch.nn.init.normal_(layer.weight, std=0.01)

    def forward(self, f_, search_):
        """Return per-sample confidence scores of shape (B,)."""
        search_ = self.down_search(search_)      # (B, 256, 25, 25)
        x = torch.cat((f_, search_), dim=1)      # (B, 768, 25, 25)
        conf = self.down_conf(x)                 # (B, 4, 25, 25)
        conf = conf.flatten(1)                   # (B, 2500)
        # BUGFIX: squeeze only the trailing singleton from Linear(512, 1).
        # A bare .squeeze() also collapsed the batch dimension when B == 1,
        # yielding a 0-dim tensor instead of shape (1,).
        conf = self.fc_blocks(conf).sigmoid().squeeze(-1)  # (B,)
        return conf
# NOTE(review): removed two lines of non-code residue (keyboard garbage and a
# scraped blog-page footer) that were not Python and broke parsing.