视频帧插学习(三):ema-vfi代码拆解和分析

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

笔者这里为什么特别要写一个拆ema-vfi的代码功能分析博文,其实原因有二:

第一:笔者为什么这么看好视频帧插,因为兴起的自媒体表明视频会越来越广泛,因此视频处理应用会比以往更广泛。而帧插是一种优秀的视频处理结构,帧插可以通过两帧或多帧帧形成新的一帧,在此基础上我们可以添加很多功能。

第二:这篇文章是transformer在帧插工作上的应用案例,具有典型性。


EMV-VFI 的主要过程在demo_Nx.py 其中 完成两帧的帧插,大概分以下几个大的过程:
1 model = Model(-1)
2 load model
3 model.eval()
4 preprocess
5 model.multiple_inference

一、model = Model(-1)

class Model:
    def __init__(self, local_rank):
        backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
        backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
        self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
        self.name = MODEL_CONFIG['LOGNAME']
        self.device()

        # train
        self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
        self.lap = LapLoss()
        if local_rank != -1:
            self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)

在这里插入图片描述
backbonetype 是feature_extractor, 其实也就是 motionformer
multiscaletype 是flow_estimation
在这里插入图片描述

class MotionFormer(nn.Module):
    def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16], 
                 mlp_ratios=[4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[2, 2, 2, 6, 2], window_sizes=[11, 11],**kwarg):
        super().__init__()
        self.depths = depths
        self.num_stages = len(embed_dims)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.conv_stages = self.num_stages - len(num_heads)

        for i in range(self.num_stages):
            if i == 0:
                block = ConvBlock(in_chans,embed_dims[i],depths[i])
            else:
                if i < self.conv_stages:
                    patch_embed = nn.Sequential(
                        nn.Conv2d(embed_dims[i-1], embed_dims[i], 3,2,1),
                        nn.PReLU(embed_dims[i])
                    )
                    block = ConvBlock(embed_dims[i],embed_dims[i],depths[i])
                else:
                    if i == self.conv_stages:
                        patch_embed = CrossScalePatchEmbed(embed_dims[:i],
                                                        embed_dim=embed_dims[i])
                    else:
                        patch_embed = OverlapPatchEmbed(patch_size=3,
                                                        stride=2,
                                                        in_chans=embed_dims[i - 1],
                                                        embed_dim=embed_dims[i])

                    block = nn.ModuleList([MotionFormerBlock(
                        dim=embed_dims[i], motion_dim=motion_dims[i], num_heads=num_heads[i-self.conv_stages], window_size=window_sizes[i-self.conv_stages], 
                        shift_size= 0 if (j % 2) == 0 else window_sizes[i-self.conv_stages] // 2,
                        mlp_ratio=mlp_ratios[i-self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer)
                        for j in range(depths[i])])

                    norm = norm_layer(embed_dims[i])
                    setattr(self, f"norm{i + 1}", norm)
                setattr(self, f"patch_embed{i + 1}", patch_embed)
            cur += depths[i]

            setattr(self, f"block{i + 1}", block)

        self.cor = {}

        self.apply(self._init_weights)

这里只是将motion former 的代码进行实例。因为这里构造过程后面的ifoward 过程非常类似。
其中 MotionFormer总共由下图模块构成
在这里插入图片描述

二、load model

    def load_model(self, name=None, rank=0):
        def convert(param):
            return {
            k.replace("module.", ""): v
                for k, v in param.items()
                if "module." in k and 'attn_mask' not in k and 'HW' not in k
            }
        if rank <= 0 :
            if name is None:
                name = self.name
            self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))

这里初始化了所有的 patch/block/norm。这里有一种写法比较trick ,就是读取模型的名字,将名字作为函数。调试的时候不是很方便。
在这里插入图片描述

三、model.eval()

    def eval(self):
        self.net.eval()

eval 这里其实是段废代码,应该是用了一些框架,将正常的eval过程用false ,做了个空运处理。

四、preprocess

在这里插入图片描述

五、model.multiple_inference

    img_list = os.listdir(input_path)
    img_list.sort(key=lambda x: int(x.split('.')[0]))  # sort

    count = 0
    for i in range(len(img_list)-1):
        print(f'=========================Start Generating=========================')

        I0 = cv2.imread(input_path + str(i) + '.jpg')
        I2 = cv2.imread(input_path + str(i + 1) + '.jpg')

        #I0 = cv2.resize(I0,dsize=None,fx=0.25,fy=0.25,interpolation=cv2.INTER_LINEAR)
        #I2 = cv2.resize(I2,dsize=None,fx=0.25,fy=0.25,interpolation=cv2.INTER_LINEAR)


        I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)
        I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)

        padder = InputPadder(I0_.shape, divisor=32)
        I0_, I2_ = padder.pad(I0_, I2_)

        images = [I0[:, :, ::-1]]
        preds = model.multi_inference(I0_, I2_, TTA=TTA, time_list=[(i+1)*(1./args.n) for i in range(args.n - 1)], fast_TTA=TTA)
        for pred in preds:
            images.append((padder.unpad(pred).detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
        images.append(I2[:, :, ::-1])


        #mimsave('example/out_Nx.gif', images, fps=args.n)
        for i in range(len(images)):
            #cv2.imwrite('./memc_result/img{}.jpg'.format(count), images[i][:, :, ::-1])
            cv2.imwrite('./memc_result/{}.jpg'.format(str(count).zfill(4)), images[i][:, :, ::-1])
            count += 1


        print(f'=========================Done=========================')

这里最核心的就是 model.multi_inference
在这里插入图片描述
其中 motion former 的 foward 如下:

    def forward(self, x1, x2):
        B = x1.shape[0] 
        x = torch.cat([x1, x2], 0)
        motion_features = []
        appearence_features = []
        xs = []
        for i in range(self.num_stages):
            motion_features.append([])
            patch_embed = getattr(self, f"patch_embed{i + 1}",None)
            block = getattr(self, f"block{i + 1}",None)
            norm = getattr(self, f"norm{i + 1}",None)
            if i < self.conv_stages:
                if i > 0:
                    x = patch_embed(x)
                x = block(x)
                xs.append(x)
            else:
                if i == self.conv_stages:
                    x, H, W = patch_embed(xs)
                else:
                    x, H, W = patch_embed(x)
                cor = self.get_cor((x.shape[0], H, W), x.device)
                for blk in block:
                    x, x_motion = blk(x, cor, H, W, B)
                    motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous())
                x = norm(x)
                x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()
                motion_features[i] = torch.cat(motion_features[i], 1)
            appearence_features.append(x)
        return appearence_features, motion_features

这里我对这个循环进行解释下
当i == 0/1/2/3/4时的过程进行注释
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
最终 motion former 这里会返回两个 list
在这里插入图片描述

持续更新中

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值