TimeSformer Code Implementation

Overall Framework

The code consists of two main parts: feature extraction (forward_features) and the classification head (head).

The input x has shape [1, 3, 8, 224, 224], i.e. [batch, channels, frames, height, width].
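As a rough orientation, the end-to-end flow can be sketched as below. This is a minimal sketch, not the repo's actual forward method; run_timesformer is a hypothetical helper and model stands in for the real class, which wires forward_features and head together internally.

import torch

def run_timesformer(model, num_frames=8):
    """Sketch of the two-stage flow: feature extraction, then the classification head."""
    x = torch.randn(1, 3, num_frames, 224, 224)  # [batch, channels, frames, height, width]
    feats = model.forward_features(x)            # -> [1, 768], the cls_token feature
    return model.head(feats)                     # -> [1, num_classes]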

Feature Extraction

    def forward_features(self, x):
        B = x.shape[0]
        
        # (1) ###############################################

        # Split x into patches and embed each patch as a single vector.
        # Input  x: [1, 3, 8, 224, 224]
        # Output x: [8, 196, 768]
        x, T, W = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        # Concatenate the cls_token with the patch tokens
        x = torch.cat((cls_tokens, x), dim=1)

        ## resizing the positional embeddings in case they don't match the input at inference
        if x.size(1) != self.pos_embed.size(1):
            pos_embed = self.pos_embed
            cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1)
            other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2)
            P = int(other_pos_embed.size(2) ** 0.5)
            H = x.size(1) // W
            other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P)
            new_pos_embed = F.interpolate(other_pos_embed, size=(H, W), mode='nearest')
            new_pos_embed = new_pos_embed.flatten(2)
            new_pos_embed = new_pos_embed.transpose(1, 2)
            new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
            x = x + new_pos_embed
        else:
            # Add the positional embedding.
            # x has shape [(b t) n m], i.e. [8, 197, 768]:
            #   batch (1) * time (8 sampled frames), 196 patches + 1 cls_token,
            #   channel dim 768 (= 16 * 16 * 3).
            # pos_embed has shape [1, 197, 768].
            # By broadcasting, pos_embed is automatically expanded to [8, 197, 768] when added,
            # so every frame in dim 0 of x [(b t) n m] receives the same embedding; only the
            # value added per patch (position) differs.
            # [8, 197, 768] + [1, 197, 768] = [8, 197, 768] + [8, 197, 768]
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # (2) ###############################################
        # Add the time embedding, analogous to the positional embedding above.
        # Note the dimension shuffle that is needed first:
        #   x: (b t) n m -> (b n) t m
        # Remove the cls_token: [8, 197, 768] -> [8, 196, 768]
        # Then rearrange:       [8, 196, 768] -> [196, 8, 768]
        ## Time Embeddings
        if self.attention_type != 'space_only':
            cls_tokens = x[:B, 0, :].unsqueeze(1)
            # Every token except the cls_token gets a time embedding
            x = x[:,1:]
            x = rearrange(x, '(b t) n m -> (b n) t m',b=B,t=T)
            ## Resizing time embeddings in case they don't match
            if T != self.time_embed.size(1):
                time_embed = self.time_embed.transpose(1, 2)
                new_time_embed = F.interpolate(time_embed, size=(T), mode='nearest')
                new_time_embed = new_time_embed.transpose(1, 2)
                x = x + new_time_embed
            else:
                # time_embed has shape [1, 8, 768]; by broadcasting, the same embedding is
                # added at each time index across all spatial positions:
                # [196, 8, 768] + [1, 8, 768] = [196, 8, 768] + [196, 8, 768]
                x = x + self.time_embed
            x = self.time_drop(x)
            x = rearrange(x, '(b n) t m -> b (n t) m',b=B,t=T)
            x = torch.cat((cls_tokens, x), dim=1)

        # (3) ###############################################
        ## Attention blocks
        for blk in self.blocks:
            x = blk(x, B, T, W)

        ### Predictions for space-only baseline
        if self.attention_type == 'space_only':
            x = rearrange(x, '(b t) n m -> b t n m',b=B,t=T)
            x = torch.mean(x, 1) # averaging predictions for every frame

        x = self.norm(x)
        return x[:, 0]
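Step (1) begins with self.patch_embed. Below is a minimal sketch of what that module presumably does, assuming the standard ViT-style 16×16 Conv2d projection applied frame by frame; the real module may differ in details.

import torch
import torch.nn as nn
from einops import rearrange

x = torch.randn(1, 3, 8, 224, 224)                   # [B, C, T, H, W]
B, C, T, H, W = x.shape
x = rearrange(x, 'b c t h w -> (b t) c h w')         # fold time into the batch: [8, 3, 224, 224]
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # one 768-dim vector per 16x16 patch
x = proj(x)                                          # [8, 768, 14, 14]
W_patches = x.size(-1)                               # 14 patches along the width (the W returned above)
x = x.flatten(2).transpose(1, 2)                     # [8, 196, 768]
print(x.shape, T, W_patches)                         # torch.Size([8, 196, 768]) 8 14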

1. Patch embedding, cls_token, and positional embedding: the code marked (1) above.
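To make the broadcasting described in the comments of step (1) concrete, here is a standalone shape check (illustrative only, with random tensors):

import torch

x = torch.randn(8, 197, 768)          # (b t) n m with b=1, t=8, n=196+1, m=768
pos_embed = torch.randn(1, 197, 768)  # one learned embedding per token position
out = x + pos_embed                   # pos_embed broadcasts across the 8 frames
print(out.shape)                      # torch.Size([8, 197, 768])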
2. Time embedding
Whether a time embedding is added depends on the attention type. The time embedding is mainly there to support the subsequent Time Attention, just as the positional embedding supports Space Attention. A small shape sketch follows.
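A standalone sketch of the reshape and broadcast used for the time embedding (same shapes as in the walkthrough, random tensors for illustration):

import torch
from einops import rearrange

B, T, N, M = 1, 8, 196, 768
x = torch.randn(B * T, N, M)                          # (b t) n m = [8, 196, 768], cls_token removed
x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)  # [196, 8, 768]
time_embed = torch.randn(1, T, M)                     # one embedding per frame index
x = x + time_embed                                    # broadcasts over the (b n) dimension
print(x.shape)                                        # torch.Size([196, 8, 768])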

3. Attention blocks
The printed structure of the attention blocks looks like this:

ModuleList(
  (0): Block(
    (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=768, out_features=2304, bias=True)
      (proj): Linear(in_features=768, out_features=768, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
      (attn_drop): Dropout(p=0.0, inplace=False)
    )
    (temporal_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (temporal_attn): Attention(
      (qkv): Linear(in_features=768, out_features=2304, bias=True)
      (proj): Linear(in_features=768, out_features=768, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
      (attn_drop): Dropout(p=0.0, inplace=False)
    )
    (temporal_fc): Linear(in_features=768, out_features=768, bias=True)
    (drop_path): Identity()
    (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (act): GELU(approximate=none)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
  )
  *** repeated depth times, as determined by the depth value passed when TimeSformer is constructed ***
)
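The repetition noted above is presumably produced along these lines. This is only a sketch using the Block class shown below (example values for the base model: depth 12, embed dim 768, 12 heads); the real constructor may also vary drop_path per block.

import torch.nn as nn

depth, embed_dim, num_heads = 12, 768, 12   # example values for the base model
blocks = nn.ModuleList([
    Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4., qkv_bias=True,
          attention_type='divided_space_time')
    for _ in range(depth)])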

The code executed by a single Block is shown below.

import torch
import torch.nn as nn
from einops import rearrange
# Also requires the Attention class shown further below and the DropPath / Mlp
# helper modules used elsewhere in the repo.

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type='divided_space_time'):
        super().__init__()
        self.attention_type = attention_type
        assert(attention_type in ['divided_space_time', 'space_only','joint_space_time'])

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
           dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        ## Temporal Attention Parameters
        if self.attention_type == 'divided_space_time':
            self.temporal_norm1 = norm_layer(dim)
            self.temporal_attn = Attention(
              dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
            self.temporal_fc = nn.Linear(dim, dim)

        ## drop path
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)


    def forward(self, x, B, T, W):
        num_spatial_tokens = (x.size(1) - 1) // T
        H = num_spatial_tokens // W
		
        # Dispatch on the attention type
        if self.attention_type in ['space_only', 'joint_space_time']:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x
        elif self.attention_type == 'divided_space_time':
            ## Temporal
            # The cls_token does not take part in the temporal attention
            xt = x[:,1:,:]
            xt = rearrange(xt, 'b (h w t) m -> (b h w) t m',b=B,h=H,w=W,t=T)
            # xt: [1, 1568, 768] -> [196, 8, 768]
            # norm -> temporal attention -> DropPath
            res_temporal = self.drop_path(self.temporal_attn(self.temporal_norm1(xt)))
            res_temporal = rearrange(res_temporal, '(b h w) t m -> b (h w t) m',b=B,h=H,w=W,t=T)
            res_temporal = self.temporal_fc(res_temporal)
            # res_temporal: [1, 1568, 768], the temporal self-attention output
            # xt below: [1, 1568, 768]; it does not contain the cls_token
            xt = x[:,1:,:] + res_temporal

            # Repeat the cls_token T times along the time dimension, reshape it to (b t) m,
            # and use it in the spatial attention
            init_cls_token = x[:,0,:].unsqueeze(1)
            cls_token = init_cls_token.repeat(1, T, 1)
            cls_token = rearrange(cls_token, 'b t m -> (b t) m',b=B,t=T).unsqueeze(1)

            ## Spatial
            xs = xt
            xs = rearrange(xs, 'b (h w t) m -> (b t) (h w) m',b=B,h=H,w=W,t=T)
            # xs: [8, 196, 768]
            # cls_token: [8, 1, 768]
            # xs after concatenation: [8, 197, 768]
            # so the input to the spatial attention is also [8, 197, 768]
            xs = torch.cat((cls_token, xs), 1)
            # res_spatial: [8, 197, 768]
            res_spatial = self.drop_path(self.attn(self.norm1(xs)))

            ### Taking care of CLS token
            cls_token = res_spatial[:,0,:]
            # cls_token: [8, 768] -> [1, 8, 768] -> mean over time -> [1, 1, 768]
            cls_token = rearrange(cls_token, '(b t) m -> b t m',b=B,t=T)
            cls_token = torch.mean(cls_token,1,True) ## averaging for every frame
            res_spatial = res_spatial[:,1:,:]
            # res_spatial: [8, 196, 768] -> [1, 1568, 768]
            res_spatial = rearrange(res_spatial, '(b t) (h w) m -> b (h w t) m',b=B,h=H,w=W,t=T)
            res = res_spatial
            x = xt

            ## MLP (fully connected) part
            x = torch.cat((init_cls_token, x), 1) + torch.cat((cls_token, res), 1)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x

After passing through the blocks, x has shape [1, 1569, 768] (1 cls_token + 196 patches × 8 frames = 1569 tokens); the final return value x[:, 0] has shape [1, 768].
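A quick shape check for one divided space-time Block, assuming the Block class above and its einops / DropPath / Mlp / Attention dependencies are importable (random input for illustration):

import torch

B, T, H, W, D = 1, 8, 14, 14, 768
blk = Block(dim=D, num_heads=12, qkv_bias=True, attention_type='divided_space_time')
x = torch.randn(B, 1 + H * W * T, D)   # [1, 1569, 768]
out = blk(x, B, T, W)
print(out.shape)                       # torch.Size([1, 1569, 768])
print(out[:, 0].shape)                 # torch.Size([1, 768]); forward_features returns x[:, 0] after the final norm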

For background, see a detailed explanation of Self-Attention and Multi-Head Attention in the Transformer.
The Attention implementation used here:

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True):
        super().__init__()
        # number of attention heads
        self.num_heads = num_heads
        # dimension of each head
        head_dim = dim // num_heads
        # the 1/sqrt(d) scaling factor from the self-attention formula
        self.scale = qk_scale or head_dim ** -0.5
        self.with_qkv = with_qkv
        # linear layers that produce q, k, v
        if self.with_qkv:
           self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
           self.proj = nn.Linear(dim, dim)
           self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x):
        B, N, C = x.shape
        if self.with_qkv:
           qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
           q, k, v = qkv[0], qkv[1], qkv[2]
        else:
           qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
           q, k, v  = qkv, qkv, qkv

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        if self.with_qkv:
           # This linear projection implements the "concatenate heads and multiply by W_O" step.
           # Because the heads already sit contiguously in the last dimension, no explicit
           # concatenation is needed, and the matrix form computes all heads in parallel.
           x = self.proj(x)
           x = self.proj_drop(x)
        return x
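A quick usage example for the Attention module above (assumes torch and torch.nn as nn are imported, as in the class definition; random input for illustration):

import torch

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
xs = torch.randn(8, 197, 768)   # the spatial-attention input from the walkthrough
print(attn(xs).shape)           # torch.Size([8, 197, 768])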

Action Classification

The returned x of shape [1, 768] is passed through the classification head

self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

which produces a score (logit) for every class; taking the index with the largest value gives the predicted class.
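A small sketch of this last step; num_classes=400 is only an example value (e.g. Kinetics-400), and the input feature is random for illustration:

import torch
import torch.nn as nn

embed_dim, num_classes = 768, 400
head = nn.Linear(embed_dim, num_classes)

feats = torch.randn(1, embed_dim)   # the x[:, 0] feature returned by forward_features
logits = head(feats)                # [1, num_classes], one score per class
pred = logits.argmax(dim=-1)        # index of the highest score = predicted action class
print(pred.shape)                   # torch.Size([1])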
