【Masked Video Distillation蒸馏损失函数】

在这里插入图片描述

Masked Video Distillation蒸馏损失函数伪码实现:

# f: student encoder  # encoder
# g_img: decoder for reconstructing spatial features  # decoder for image
# g_vid: decoder for reconstructing spatial-temporal  # decoder for video
features
# t_m: learnable mask tokens # 掩码token
# h_img: image teacher model  #image teacher
# h_vid: video teacher model  #video teacher
for x, m in loader: # x: video data, m: mask
x_pe = patch_emb(x) # patch embedding of input #patch embedding
x_vis = mask_select(x_pe, 1 - m) # masking tokens # 可见token
q_vis = f(x_vis) # visible local patch features  # 编码结果
# reconstruction of target features
p_img = g_img(concat(q_vis, t_m)) # 重建image结果
p_vid = g_vid(concat(q_vis, t_m)) # 重建video结果
# compute target features with teacher models
k_img = h_img(x) # target spatial features  # image teacher 预测结果
k_vid = h_vid(x) # target spatial-temporal features # video teacher 预测结果
# compute reconstruction loss
loss_img = smooth_L1_loss(p_img ? m, k_img ? m) # image loss
loss_vid = smooth_L1_loss(p_vid ? m, k_vid ? m) # video loss
loss = λ1 * loss_img + λ2 * loss_vid #总体损失函数
loss.backward()
optimizer.step() # optimizer update```

Masked Video Distillation蒸馏损失函数源码实现:

        with torch.cuda.amp.autocast():
            output_features, output_video_features = model(videos, bool_masked_pos)
            with torch.no_grad():
                image_teacher_model.eval()  #训练image teacher
                if time_stride_loss:
                    teacher_features = image_teacher_model(  #得到image teacher的预测结果
                        rearrange(videos_for_teacher[:, :, ::tubelet_size, :, :], 'b c t h w -> (b t) c h w'),
                    )
                    teacher_features = rearrange(teacher_features, '(b t) l c -> b (t l) c', t=T//tubelet_size)
                else:
                    teacher_features = image_teacher_model(
                        rearrange(videos_for_teacher, 'b c t h w -> (b t) c h w'),
                    )
                    teacher_features = rearrange(teacher_features, '(b t d) l c -> b (t l) (d c)', t=T//tubelet_size, d=tubelet_size)
                if norm_feature:
                    teacher_features = LN_img(teacher_features)

                video_teacher_model.eval()  # 训练video teacher
                videos_for_video_teacher = videos if args.video_teacher_input_size == args.input_size \
                    else videos_for_teacher

                video_teacher_features = video_teacher_model(videos_for_video_teacher)#得到video teacher的预测结果
                if norm_feature:
                    video_teacher_features = LN_vid(video_teacher_features)

            B, _, D = output_features.shape
            loss_img_feat = loss_func_img_feat(#image teacher 的 loss
                input=output_features,
                target=teacher_features[bool_masked_pos].reshape(B, -1, D)
            )
            loss_value_img_feat = loss_img_feat.item()

            B, _, D = output_video_features.shape
            loss_vid_feat = loss_func_vid_feat(#video teacher 的 loss
                input=output_video_features,
                target=video_teacher_features[bool_masked_pos].reshape(B, -1, D)
            )
            loss_value_vid_feat = loss_vid_feat.item()

            loss = image_loss_weight * loss_img_feat + video_loss_weight * loss_vid_feat#总的损失函数
            ```

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值