# tracker_utile

import torch.nn as nn
import torch.nn.functional as F
import torch as t
import math
from pysot.models.utile_tctrack.trantime import Transformertime

    
class TCT(nn.Module):
    """Correlation + temporal-transformer prediction head.

    The module depthwise-correlates feature maps with a kernel feature map
    ``z``, folds a sequence of earlier-frame correlation responses into a
    running memory through ``Transformertime.encoder``, decodes the current
    frame against that memory, and finally emits three maps: a 4-channel
    ``loc`` map, a 2-channel ``cls1`` map and a 1-channel ``cls2`` map.
    """

    def __init__(self, cfg):
        """Create all sub-modules and initialise the convolution weights.

        Args:
            cfg: Configuration object. It is accepted for interface
                compatibility; this constructor does not read it.
        """
        super().__init__()

        in_channels = 256  # channel count the correlation output must have
        channel = 192

        # NOTE: sub-modules are created in a fixed order (conv1, conv2,
        # convloc, convcls, transformer, cls1, cls2) and every Sequential keeps
        # its layer order, so state_dict keys and RNG consumption stay stable.
        self.conv1 = nn.Sequential(
            *self._conv_bn_relu(in_channels, channel, stride=2, bias=False)
        )
        self.conv2 = nn.Sequential(
            *self._conv_bn_relu(in_channels, channel, stride=2, bias=False)
        )

        self.convloc = nn.Sequential(
            *self._tower(channel, depth=4),
            nn.Conv2d(channel, 4, kernel_size=3, stride=1, padding=1),
        )
        self.convcls = nn.Sequential(*self._tower(channel, depth=4))

        # NOTE(review): the positional arguments are presumably
        # (d_model, heads, encoder layers, decoder layers) -- confirm against
        # the Transformertime definition.
        self.transformer = Transformertime(channel, 6, 1, 2)

        self.cls1 = nn.Conv2d(channel, 2, kernel_size=3, stride=1, padding=1)
        self.cls2 = nn.Conv2d(channel, 1, kernel_size=3, stride=1, padding=1)

        # Only the convolution *weights* of the heads listed here are
        # re-initialised; biases, BatchNorm layers and the transformer keep
        # their default initialisation.
        heads = (self.conv1, self.conv2, self.convloc, self.convcls,
                 self.cls1, self.cls2)
        for head in heads:
            for layer in head.modules():
                if isinstance(layer, nn.Conv2d):
                    t.nn.init.normal_(layer.weight, std=0.01)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, stride=1, bias=True):
        """Return the layers of one 3x3 Conv -> BatchNorm -> ReLU unit."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride,
                      padding=1, bias=bias),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    @classmethod
    def _tower(cls, channel, depth):
        """Return ``depth`` stacked channel-preserving Conv-BN-ReLU units."""
        layers = []
        for _ in range(depth):
            layers.extend(cls._conv_bn_relu(channel, channel))
        return layers

    @staticmethod
    def _to_sequence(feat):
        """Flatten a (B, C, H, W) map into a (H*W, B, C) sequence."""
        return feat.flatten(2).permute(2, 0, 1)

    @staticmethod
    def _to_map(seq, size):
        """Fold a (H*W, B, C) sequence back into a map of shape ``size``."""
        return seq.permute(1, 2, 0).reshape(size)

    def reset_parameters(self):
        """Uniformly re-initialise ``row_embed`` / ``col_embed`` weights.

        ``__init__`` of this class creates neither attribute, so calling the
        unguarded version always raised ``AttributeError``. The method now
        re-initialises whichever of the two embeddings is present (for example
        when attached by a subclass) and is a no-op otherwise.
        """
        for name in ("row_embed", "col_embed"):
            embed = getattr(self, name, None)
            if embed is not None:
                nn.init.uniform_(embed.weight)

    def xcorr_depthwise(self, x, kernel):
        """Depthwise cross correlation of ``x`` with ``kernel``.

        Every (sample, channel) plane of ``x`` is correlated with the matching
        plane of ``kernel`` by folding batch and channel into the group
        dimension of a single grouped convolution.

        Args:
            x: Tensor of shape (B, C, H, W).
            kernel: Tensor of shape (B, C, kh, kw) with the same B and C.

        Returns:
            Tensor of shape (B, C, H - kh + 1, W - kw + 1).
        """
        batch, channel = kernel.size(0), kernel.size(1)
        groups = batch * channel
        # reshape (not view) on both operands so non-contiguous inputs work.
        x = x.reshape(1, groups, x.size(2), x.size(3))
        kernel = kernel.reshape(groups, 1, kernel.size(2), kernel.size(3))
        out = F.conv2d(x, kernel, groups=groups)
        return out.reshape(batch, channel, out.size(2), out.size(3))

    def forward(self, x, z, px):
        """Predict localisation and classification maps for ``x``.

        Args:
            x: Current-frame feature map, (B, 256, H, W).
            z: Feature map used as the correlation kernel, (B, 256, kh, kw).
            px: Non-empty indexable sequence of earlier-frame feature maps,
                each shaped like ``x``; processed in order.

        Returns:
            Tuple ``(loc, cls1, cls2)`` with 4, 2 and 1 channels respectively.

        Raises:
            IndexError: If ``px`` is empty.
        """
        # The first earlier frame both seeds the memory (through conv1) and is
        # the first loop input (through conv2); correlate it only once.
        first_corr = self.xcorr_depthwise(px[0], z)
        ppres = self.conv1(first_corr)

        for idx, frame in enumerate(px):
            corr = first_corr if idx == 0 else self.xcorr_depthwise(frame, z)
            feat = self.conv2(corr)
            # assumes the encoder output has the same shape as its first
            # argument -- TODO confirm against Transformertime.encoder
            memory = self.transformer.encoder(self._to_sequence(feat),
                                              self._to_sequence(ppres))
            ppres = self._to_map(memory, feat.size())

        res3 = self.conv2(self.xcorr_depthwise(x, z))
        query = self._to_sequence(res3)
        _, res = self.transformer(query, self._to_sequence(ppres), query)

        # Use the current frame's own shape rather than the one left over from
        # the last loop iteration.
        res = self._to_map(res, res3.size())

        loc = self.convloc(res)
        acls = self.convcls(res)

        cls1 = self.cls1(acls)
        cls2 = self.cls2(acls)

        return loc, cls1, cls2






  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值