Preface
Why write a post specifically dissecting the EMA-VFI code? There are two reasons.

First: I am bullish on video frame interpolation. The rise of self-media means video will only become more pervasive, so video-processing applications will be broader than ever. Frame interpolation is an excellent building block: it synthesizes a new frame from two (or more) existing frames, and many features can be layered on top of it.

Second: this work is a representative case study of applying Transformers to frame interpolation.
The main pipeline of EMA-VFI lives in demo_Nx.py, which interpolates between two frames. It breaks down into the following major steps (a sketch of how they fit together follows the list):
1. model = Model(-1)
2. load model
3. model.eval()
4. preprocess
5. model.multi_inference
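As a quick orientation, here is a minimal sketch of how these five steps chain together in demo_Nx.py; the argument values (TTA=False, time_list=[0.5]) are illustrative assumptions, pieced together from the snippets shown later in this post:

# minimal sketch, not verbatim demo_Nx.py
model = Model(-1)        # local_rank = -1: single process, no DDP
model.load_model()       # loads ckpt/<LOGNAME>.pkl into model.net
model.eval()             # put the wrapped network into inference mode
# preprocess: two frames -> normalized CUDA tensors, padded to multiples of 32
preds = model.multi_inference(I0_, I2_, TTA=False,
                              time_list=[0.5],   # one frame at t = 0.5
                              fast_TTA=False)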
1. model = Model(-1)
class Model:
    def __init__(self, local_rank):
        backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
        backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
        # build the backbone first, then hand it to the multi-scale head
        self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
        self.name = MODEL_CONFIG['LOGNAME']
        self.device()
        # train
        self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
        self.lap = LapLoss()
        if local_rank != -1:   # -1 (as in the demo) skips DDP entirely
            self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)
Here backbonetype is the feature_extractor, i.e. MotionFormer.
multiscaletype is flow_estimation, the multi-scale flow-estimation head.
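So the single line self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg) composes the two; spelled out with the concrete names (a paraphrase, not code from the repo):

# equivalent to the __init__ line above, with the names substituted in
backbone = MotionFormer(**backbonecfg)            # backbonetype = feature_extractor
net = flow_estimation(backbone, **multiscalecfg)  # multiscaletype owns the backbone
# the head calls the backbone to get appearance/motion feature pyramids,
# then estimates inter-frame flows and fuses the warped frames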
class MotionFormer(nn.Module):
    def __init__(self, in_chans=3, embed_dims=[32, 64, 128, 256, 512], motion_dims=64, num_heads=[8, 16],
                 mlp_ratios=[4, 4], qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[2, 2, 2, 6, 2], window_sizes=[11, 11], **kwarg):
        super().__init__()
        self.depths = depths
        self.num_stages = len(embed_dims)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        # the first (num_stages - len(num_heads)) stages are convolutional,
        # the remaining ones are transformer stages
        self.conv_stages = self.num_stages - len(num_heads)
        for i in range(self.num_stages):
            if i == 0:
                # stage 0: plain conv block, no downsampling patch embed
                block = ConvBlock(in_chans, embed_dims[i], depths[i])
            else:
                if i < self.conv_stages:
                    # remaining conv stages: a stride-2 conv halves the resolution
                    patch_embed = nn.Sequential(
                        nn.Conv2d(embed_dims[i-1], embed_dims[i], 3, 2, 1),
                        nn.PReLU(embed_dims[i])
                    )
                    block = ConvBlock(embed_dims[i], embed_dims[i], depths[i])
                else:
                    if i == self.conv_stages:
                        # first transformer stage: fuse all conv-stage outputs
                        patch_embed = CrossScalePatchEmbed(embed_dims[:i],
                                                           embed_dim=embed_dims[i])
                    else:
                        patch_embed = OverlapPatchEmbed(patch_size=3,
                                                        stride=2,
                                                        in_chans=embed_dims[i - 1],
                                                        embed_dim=embed_dims[i])
                    # window-attention blocks; odd blocks use shifted windows
                    # (motion_dims is indexed per stage: the config passes a list
                    # despite the scalar default)
                    block = nn.ModuleList([MotionFormerBlock(
                        dim=embed_dims[i], motion_dim=motion_dims[i],
                        num_heads=num_heads[i-self.conv_stages],
                        window_size=window_sizes[i-self.conv_stages],
                        shift_size=0 if (j % 2) == 0 else window_sizes[i-self.conv_stages] // 2,
                        mlp_ratio=mlp_ratios[i-self.conv_stages], qkv_bias=qkv_bias, qk_scale=qk_scale,
                        drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
                        norm_layer=norm_layer)
                        for j in range(depths[i])])
                    norm = norm_layer(embed_dims[i])
                    setattr(self, f"norm{i + 1}", norm)
                setattr(self, f"patch_embed{i + 1}", patch_embed)
            cur += depths[i]
            setattr(self, f"block{i + 1}", block)
        self.cor = {}  # cache for the coordinate grids used by get_cor
        self.apply(self._init_weights)
The code above only shows how MotionFormer is instantiated, because the construction closely mirrors the forward pass described later.
Overall, MotionFormer is built from three kinds of modules: ConvBlock stages, patch-embedding layers (the stride-2 convs, CrossScalePatchEmbed, and OverlapPatchEmbed), and stacked MotionFormerBlock stages.
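To make the conv/transformer split concrete, here is a tiny standalone computation of the stage layout under the default arguments (purely illustrative):

embed_dims = [32, 64, 128, 256, 512]
num_heads = [8, 16]                        # only transformer stages get heads
depths = [2, 2, 2, 6, 2]

num_stages = len(embed_dims)               # 5
conv_stages = num_stages - len(num_heads)  # 5 - 2 = 3

for i in range(num_stages):
    kind = "ConvBlock" if i < conv_stages else "MotionFormerBlock"
    print(f"stage {i}: {kind} x {depths[i]}")
# stage 0: ConvBlock x 2
# stage 1: ConvBlock x 2
# stage 2: ConvBlock x 2
# stage 3: MotionFormerBlock x 6
# stage 4: MotionFormerBlock x 2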
2. load model
def load_model(self, name=None, rank=0):
    def convert(param):
        return {
            k.replace("module.", ""): v
            for k, v in param.items()
            if "module." in k and 'attn_mask' not in k and 'HW' not in k
        }
    if rank <= 0:
        if name is None:
            name = self.name
        self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
This loads the weights for all the patch/block/norm modules built earlier. One idiom here is a bit tricky: the checkpoint filename is derived from the model's name string, and modules are looked up by name (the setattr/getattr f-string pattern above), which is not very convenient when debugging.
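The convert() helper is there because checkpoints saved from a DDP-wrapped model prefix every parameter key with "module."; convert() strips the prefix and drops cached buffers like attn_mask. A self-contained illustration with made-up keys:

# hypothetical state-dict keys, just to show the filtering
sd = {
    'module.feature_bone.block1.0.weight': 'kept',
    'module.block4.0.attn_mask': 'dropped',   # cached buffer, rebuilt at runtime
    'stray_key': 'dropped',                   # no "module." prefix
}
clean = {k.replace('module.', ''): v
         for k, v in sd.items()
         if 'module.' in k and 'attn_mask' not in k and 'HW' not in k}
print(clean)  # {'feature_bone.block1.0.weight': 'kept'}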
3. model.eval()
def eval(self):
    self.net.eval()
This eval is little more than boilerplate: it simply forwards to the wrapped network's eval(), which recursively sets training=False so that layers like dropout and batch norm switch to inference behavior; the wrapper itself adds nothing.
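A quick standalone check of what eval() actually does; this is standard PyTorch behavior, nothing specific to this repo:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Dropout(p=0.5))
net.eval()                    # recursively sets .training = False
print(net.training)           # False

x = torch.ones(1, 3, 8, 8)
with torch.no_grad():
    y1, y2 = net(x), net(x)
print(torch.equal(y1, y2))    # True: dropout is the identity in eval mode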
4. preprocess
There is no separate preprocessing function: the preprocessing (BGR image to normalized CUDA tensor, plus padding H and W to multiples of 32 so every downsampling stage divides evenly) happens inline in the demo loop shown in the next section.
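The repo's InputPadder takes care of the padding; as a hedged sketch of the constraint it enforces, here is a hypothetical pad_to_multiple helper of my own (not the repo's class):

import torch
import torch.nn.functional as F

def pad_to_multiple(x, divisor=32):
    """Pad an NCHW tensor on the right/bottom so H and W divide by `divisor`."""
    h, w = x.shape[-2:]
    ph = (divisor - h % divisor) % divisor
    pw = (divisor - w % divisor) % divisor
    return F.pad(x, (0, pw, 0, ph))   # (left, right, top, bottom)

x = torch.rand(1, 3, 270, 480)
print(pad_to_multiple(x).shape)       # torch.Size([1, 3, 288, 480])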
5. model.multi_inference
img_list = os.listdir(input_path)
img_list.sort(key=lambda x: int(x.split('.')[0]))  # numeric sort: 0.jpg, 1.jpg, ...
count = 0
for i in range(len(img_list) - 1):
    print('=========================Start Generating=========================')
    I0 = cv2.imread(input_path + str(i) + '.jpg')
    I2 = cv2.imread(input_path + str(i + 1) + '.jpg')
    #I0 = cv2.resize(I0, dsize=None, fx=0.25, fy=0.25, interpolation=cv2.INTER_LINEAR)
    #I2 = cv2.resize(I2, dsize=None, fx=0.25, fy=0.25, interpolation=cv2.INTER_LINEAR)
    # HWC uint8 BGR -> NCHW float tensor in [0, 1]
    I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)
    I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0)
    # pad H and W up to multiples of 32 so the feature pyramid divides evenly
    padder = InputPadder(I0_.shape, divisor=32)
    I0_, I2_ = padder.pad(I0_, I2_)
    images = [I0[:, :, ::-1]]  # BGR -> RGB
    preds = model.multi_inference(I0_, I2_, TTA=TTA,
                                  time_list=[(t + 1) * (1. / args.n) for t in range(args.n - 1)],
                                  fast_TTA=TTA)
    for pred in preds:
        images.append((padder.unpad(pred).detach().cpu().numpy()
                       .transpose(1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1])
    images.append(I2[:, :, ::-1])
    #mimsave('example/out_Nx.gif', images, fps=args.n)
    # j, not i: the original shadowed the outer loop variable here
    # (also note I2 of this pair is written again as I0 of the next pair)
    for j in range(len(images)):
        #cv2.imwrite('./memc_result/img{}.jpg'.format(count), images[j][:, :, ::-1])
        cv2.imwrite('./memc_result/{}.jpg'.format(str(count).zfill(4)), images[j][:, :, ::-1])
        count += 1
    print('=========================Done=========================')
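Note how time_list is built: for Nx interpolation the demo requests N-1 evenly spaced timestamps strictly between the two inputs. For example, with args.n = 4:

n = 4
time_list = [(t + 1) * (1. / n) for t in range(n - 1)]
print(time_list)   # [0.25, 0.5, 0.75]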
The core of this loop is the model.multi_inference call.
Inside the model, MotionFormer's forward looks like this:
def forward(self, x1, x2):
    B = x1.shape[0]
    x = torch.cat([x1, x2], 0)  # run both frames through one (2B, C, H, W) batch
    motion_features = []
    appearence_features = []
    xs = []
    for i in range(self.num_stages):
        motion_features.append([])
        patch_embed = getattr(self, f"patch_embed{i + 1}", None)
        block = getattr(self, f"block{i + 1}", None)
        norm = getattr(self, f"norm{i + 1}", None)
        if i < self.conv_stages:
            if i > 0:          # stage 0 has no patch_embed
                x = patch_embed(x)
            x = block(x)
            xs.append(x)       # keep conv-stage outputs for the cross-scale embed
        else:
            if i == self.conv_stages:
                x, H, W = patch_embed(xs)   # CrossScalePatchEmbed fuses all of xs
            else:
                x, H, W = patch_embed(x)
            cor = self.get_cor((x.shape[0], H, W), x.device)  # coordinate grid, cached in self.cor
            for blk in block:
                x, x_motion = blk(x, cor, H, W, B)
                motion_features[i].append(x_motion.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous())
            x = norm(x)
            x = x.reshape(2*B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            # concatenate every block's motion map along the channel dim
            motion_features[i] = torch.cat(motion_features[i], 1)
        appearence_features.append(x)
    return appearence_features, motion_features
Let me unpack this loop stage by stage (with the defaults, conv_stages = 5 - 2 = 3):

- i == 0: there is no patch_embed; a ConvBlock runs directly on the concatenated frame pair.
- i == 1, 2: a stride-2 conv patch_embed halves the resolution, then a ConvBlock runs; every conv-stage output is also stashed in xs.
- i == 3: the first transformer stage. CrossScalePatchEmbed fuses all the conv-stage outputs collected in xs, and each MotionFormerBlock returns both the appearance feature x and a motion feature x_motion.
- i == 4: OverlapPatchEmbed downsamples once more and the final stack of MotionFormerBlocks runs the same way.

In the end, MotionFormer returns two lists: appearence_features (one appearance tensor per stage) and motion_features (empty lists for the conv stages, concatenated per-block motion maps for the transformer stages).
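As a hedged sketch of what a caller gets back (the shapes are my reading of the loop above, with spatial size halving at each patch embed; the motion channel count assumes each block emits motion_dims channels):

app_feats, mot_feats = motionformer(x1, x2)   # x1, x2: (B, 3, H, W)

for i, f in enumerate(app_feats):
    print('stage', i, f.shape)        # (2B, embed_dims[i], H_i, W_i)

for i, m in enumerate(mot_feats):
    if isinstance(m, list):           # conv stages: still the empty list
        print('stage', i, 'no motion features')
    else:                             # transformer stages: depths[i] x_motion
        print('stage', i, m.shape)    # maps concatenated along channels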
To be continued.