Video_based_ReID_RNN

前言

接下来,我们就来看看视频行人重识别训练模型的其中一种temporal aggregation method:RNN。
这是在序列模型训练中常用的一种模型,RNN可以提取到连续图像蕴含的信息,这里使用的是最简单的RNN结构。
目前这种方式的试验结果不如其他几种,如B部分:
在这里插入图片描述

模型输入

输入和之前的相同 差别只在经过的网络:

  • imgs
    • imgs.size() = [b,s,c,h,w]
    • 在训练级中 b为batch通常设置为32,seq_len设置为4,c为通道数为3,h图片高,w图片宽

模型初始化参数

        model = models.init_model(name=args.arch, num_classes=dataset.num_train_pids, loss={'xent', 'htri'})
  • name 使用的模型名称
  • dataset.num_train_pids 分类时的分类数
  • loss xent=交叉熵损失 htri=Tripletloss

模型实现

class ResNet50RNN(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(ResNet50RNN, self).__init__()
        self.loss = loss
        resnet50 = torchvision.models.resnet50(pretrained=True)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.hidden_dim = 512
        self.feat_dim = 2048
        self.classifier = nn.Linear(self.hidden_dim, num_classes)
        #                   输入特征维数2048              LSTM中隐层的维度              循环神经网络的层数
        #输入的数据shape=(batch_size,seq_length,embedding_dim),而batch_first默认是False,所以我们的输入数据最好送进LSTM之前将batch_size与seq_length这两个维度调换?
        self.lstm = nn.LSTM(input_size=self.feat_dim, hidden_size=self.hidden_dim, num_layers=1, batch_first=True)

	# x = [32,4,3,224,112] [b,s,c,h,w]
    def forward(self, x):
        # b=32
        b = x.size(0)
        # t= 4
        t = x.size(1)
        # x = [128,3,224,112]
        x = x.view(b*t,x.size(2), x.size(3), x.size(4))
		# x = [128,2048,7,4]
        x = self.base(x)
		# [128,2048,1,1]
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(b,t,-1)
        # x = [32,2048,4]
        # 使用RNN直接获取特征
        # output = [32,4,512]?
        output, (h_n, c_n) = self.lstm(x)
        # output = [32,512,4]
        output = output.permute(0, 2, 1)
        # f = [32,512]
        f = F.avg_pool1d(output, t)
        f = f.view(b, self.hidden_dim)
        
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值