predictive-filter-flow

code: https://github.com/aimerykong/predictive-filter-flow

Image Reconstruction with Predictive Filter Flow
Multigrid Predictive Filter Flow
This post covers these two papers.

1. In the first paper, the method is applied to three main image-restoration tasks:

deblur, dejpeg, super-res

The figure shows the top-10 principal components of the learned filters.

[Figure: top-10 principal components of the learned filters]

2. Training process

The author's repository includes the deblurring code.

The same procedure applies to dejpeg. What about super-resolution? It works the same way too, because the input is first bicubic-upsampled to the target size and then fed to the author's network.
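
For instance, a toy sketch of that pre-upsampling step (hypothetical shapes and scale factor, not from the repo):

import torch
import torch.nn.functional as F

lr = torch.rand(1, 3, 64, 64)  # toy low-resolution batch
sr_input = F.interpolate(lr, scale_factor=4, mode='bicubic', align_corners=False)
print(sr_input.shape)  # torch.Size([1, 3, 256, 256]); this is what the network sees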

2.1 Network architecture

I don't think the exact architecture is the key part; you can design your own.
The network takes a noisy image (n, c, h, w) and outputs filters (n, 17x17, h, w), where 17 is the filter neighborhood size: one 17×17 filter is predicted for every pixel.
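
Just to make that input/output contract concrete, here is a minimal stand-in I wrote (hypothetical, not the author's architecture):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyFilterPredictor(nn.Module):
    # hypothetical stand-in: (n, 3, h, w) -> (n, 17*17, h, w) per-pixel filters
    def __init__(self, filter_size=17):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(True),
            nn.Conv2d(64, filter_size ** 2, 3, padding=1),
        )

    def forward(self, x):
        # softmax over the filter-tap dimension so each per-pixel filter sums to 1
        return F.softmax(self.body(x), dim=1)

filters = TinyFilterPredictor()(torch.rand(1, 3, 32, 32))
print(filters.shape)  # torch.Size([1, 289, 32, 32])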

2.2 Loss

It is just a reconstruction loss:
image1 is filtered with the predicted filters and the result is compared against image2.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LossOrderedPairReconstruction(nn.Module):
    def __init__(self, device='cpu', filterSize=11):
        super(LossOrderedPairReconstruction, self).__init__()
        self.device = device
        self.filterSize = filterSize
        self.filterSize2Channel = self.filterSize**2
        self.reconstructImage = 0

    def forward(self, image1, image2, filters_img1_to_img2):
        N, C, H, W = image1.size()
        self.reconstructImage = self.rgbImageFilterFlow(image1, filters_img1_to_img2)
        diff = torch.abs(self.reconstructImage - image2)
        totloss = torch.sum(diff)   # L1 reconstruction error
        return totloss / (N*C*H*W)

    # Applies the predicted per-pixel filters; each channel is processed
    # separately. unfold extracts every pixel's filterSize x filterSize
    # neighborhood, which is multiplied element-wise with the filters and
    # summed. The filters are softmax-normalized, so sum(filter) = 1.
    def rgbImageFilterFlow(self, img, filters):
        N = img.size(0)
        paddingFunc = nn.ZeroPad2d(self.filterSize // 2)
        img = paddingFunc(img)
        imgSize = [img.size(2), img.size(3)]

        out_R = F.unfold(img[:,0,:,:].unsqueeze(1), (self.filterSize, self.filterSize))
        out_R = out_R.view(N, out_R.size(1), imgSize[0]-self.filterSize+1, imgSize[1]-self.filterSize+1)
        out_R = torch.mul(out_R, filters)
        out_R = torch.sum(out_R, dim=1).unsqueeze(1)

        out_G = F.unfold(img[:,1,:,:].unsqueeze(1), (self.filterSize, self.filterSize))
        out_G = out_G.view(N, out_G.size(1), imgSize[0]-self.filterSize+1, imgSize[1]-self.filterSize+1)
        out_G = torch.mul(out_G, filters)
        out_G = torch.sum(out_G, dim=1).unsqueeze(1)

        out_B = F.unfold(img[:,2,:,:].unsqueeze(1), (self.filterSize, self.filterSize))
        out_B = out_B.view(N, out_B.size(1), imgSize[0]-self.filterSize+1, imgSize[1]-self.filterSize+1)
        out_B = torch.mul(out_B, filters)
        out_B = torch.sum(out_B, dim=1).unsqueeze(1)
        return torch.cat([out_R, out_G, out_B], 1)
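
As a design note, the three identical per-channel branches could be collapsed into a single unfold over all channels; a sketch (mine, not the repo's):

import torch
import torch.nn.functional as F

def image_filter_flow(img, filters, filter_size=11):
    # img: (n, c, h, w); filters: (n, filter_size**2, h, w), softmax-normalized
    n, c, h, w = img.shape
    patches = F.unfold(img, filter_size, padding=filter_size // 2)  # (n, c*k*k, h*w)
    patches = patches.view(n, c, filter_size ** 2, h, w)
    return (patches * filters.unsqueeze(1)).sum(2)  # broadcast filters over channels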

2.3 Visualization

The network and loss described above are really all there is to it; nothing complicated. Similar ideas appear in other papers as well, e.g. kernel prediction networks.

What the paper mainly adds here is a way to understand the filters visually by converting them to flow; see the official source code if interested.

I wrote a more concise version that also converts the filters to flow; its colors seem to differ slightly from the author's flow visualization.

The two figures below are the author's filter-flow visualization and mine, respectively.
[Figure: author's filter-flow visualization]

[Figure: my filter-flow visualization]

My code:

import numpy as np
from matplotlib import pyplot as plt
from torchvision.utils import flow_to_image  # requires torchvision >= 0.12
import torch


def flow_to_image_torch(flow):
    # flow: (h, w, 2) numpy array -> RGB flow visualization via torchvision
    flow = torch.from_numpy(np.transpose(flow, [2, 0, 1]))
    flow_im = flow_to_image(flow)
    img = np.transpose(flow_im.numpy(), [1, 2, 0])
    print(img.shape, img.dtype)
    return img

def filter2flow(filter, filter_size=11):
    """Convert per-pixel filters to a flow field: the flow at each pixel is the
    expected displacement under the softmax-normalized filter weights.

    :param filter: (n, filter_size * filter_size, h, w); n must be 1
    :param filter_size: spatial size of the per-pixel filters
    :return: (h, w, 2) flow field
    """
    n, c, h, w = filter.shape
    assert c == filter_size * filter_size
    filter = filter.view(-1, filter_size * filter_size, h, w)[0]
    filter = filter.permute(1, 2, 0).cpu().numpy()  # h, w, c

    # displacement of each filter tap, e.g. -8..8 for filter_size=17 as used below
    x = np.arange(-(filter_size // 2), filter_size // 2 + 1)
    y = np.arange(-(filter_size // 2), filter_size // 2 + 1)
    xx, yy = np.meshgrid(x, y)

    xx = xx.reshape(1, 1, -1)
    yy = yy.reshape(1, 1, -1)
    u = np.sum(filter * xx, axis=-1)
    v = np.sum(filter * yy, axis=-1)

    flow = np.dstack((u, v))
    return flow

# filterFlowMap: the predicted filters for one image, a (17*17, h, w) numpy array
filter = torch.from_numpy(filterFlowMap.astype(np.float32))
print(filter.sum(axis=0))   # ~1 everywhere, since the filters are softmaxed
print(filter.shape)
filter = filter[None, ...]
flow = filter2flow(filter, 17).astype(np.float32)
print(flow.shape, flow.min(), flow.max())
flowimg = flow_to_image_torch(flow)
plt.figure()
plt.imshow(flowimg)
plt.show()

3. Multigrid Predictive Filter Flow for Unsupervised Learning on Videos

This paper is more interesting: it introduces an unsupervised approach and a way to train on videos.

3.1 Network

[Figure: Siamese network architecture]

The design has some distinctive points. First, features (n, 128, h, w) are computed for both images;
after concatenation, the network outputs filters (n, 11x11, h, w), where 11 is the filter size.

class SiamesePixelEmbed(nn.Module):
    def __init__(self, emb_dimension=32, filterSize=11, device='cpu', pretrained=False):
        super(SiamesePixelEmbed, self).__init__()
        self.device = device
        self.emb_dimension = emb_dimension  
        self.PEMbase = MultigridPFF_tiny(emb_dimension=self.emb_dimension) 
        
        self.rawEmbFeature1 = 0
        self.rawEmbFeature2 = 0        
        self.embFeature1_to_2 = 0
        self.embFeature2_to_1 = 0
        self.filterSize = filterSize
        self.filterSize2Channel = self.filterSize**2
                
        self.ordered_embedding = nn.Sequential(            
            nn.Conv2d(self.emb_dimension*2, self.filterSize2Channel, kernel_size=3, 
                      dilation=1, padding=1, bias=False),
            nn.BatchNorm2d(self.filterSize2Channel, momentum=0.001),     
            nn.ReLU(True),
            nn.Conv2d(self.filterSize2Channel, self.filterSize2Channel, kernel_size=3, 
                      dilation=1, padding=1, bias=False),
            nn.BatchNorm2d(self.filterSize2Channel, momentum=0.001),     
            nn.ReLU(True),
            nn.Conv2d(self.filterSize2Channel, self.filterSize2Channel, kernel_size=3, 
                      padding=1, bias=False),
            nn.BatchNorm2d(self.filterSize2Channel, momentum=0.001),  
            nn.ReLU(True),          
            nn.Conv2d(self.filterSize2Channel, self.filterSize2Channel, kernel_size=3,
                      padding=1, bias=True)
        )
        
        
    def forward(self, inputs1, inputs2): 
        self.rawEmbFeature1 = self.PEMbase.forward(inputs1) # n,128,h,w
        self.rawEmbFeature2 = self.PEMbase.forward(inputs2)
        
        img1_to_img2 = torch.cat([self.rawEmbFeature1, self.rawEmbFeature2], 1)
        img2_to_img1 = torch.cat([self.rawEmbFeature2, self.rawEmbFeature1], 1)
        
        self.embFeature1_to_2 = self.ordered_embedding(img1_to_img2)        
        self.embFeature1_to_2 = F.softmax(self.embFeature1_to_2, 1)
        
        self.embFeature2_to_1 = self.ordered_embedding(img2_to_img1)
        self.embFeature2_to_1 = F.softmax(self.embFeature2_to_1, 1)
        
        return self.embFeature2_to_1, self.embFeature1_to_2    # n,11x11,h,w
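
A usage sketch, assuming the repo (which provides MultigridPFF_tiny) is on the import path:

import torch

model = SiamesePixelEmbed(emb_dimension=32, filterSize=11)
imgA = torch.rand(1, 3, 64, 64)
imgB = torch.rand(1, 3, 64, 64)
pff_BtoA, pff_AtoB = model(imgA, imgB)
print(pff_AtoB.shape)  # (1, 121, h', w'); each per-pixel 11x11 filter sums to 1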

3.2 Data input

The images are first cropped to 512×512.

Then inputa and inputb are both downscaled by factors of 2, 4, 8, 16, and 32, giving five scales. Any resize will do (OpenCV resize or otherwise); the author's code uses transforms.Resize. A pyramid sketch follows the figure.
[Figure: five-scale input pyramid]
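
A minimal way to build that pyramid (a sketch using F.interpolate; the repo itself uses transforms.Resize):

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 512, 512)  # a cropped frame
pyramid = {s: F.interpolate(img, scale_factor=1.0 / s, mode='bilinear',
                            align_corners=False)
           for s in (2, 4, 8, 16, 32)}
for s, t in pyramid.items():
    print(s, tuple(t.shape))  # 2 -> (1, 3, 256, 256), ..., 32 -> (1, 3, 16, 16)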

3.3 Loss

The method is unsupervised, so there are quite a few loss terms, and quite a few hyperparameters as well.

Here are several of them:

1) Filter loss: reconstruct the image with the predicted filters and penalize the reconstruction error.
2) Flow loss: convert the filters to optical flow, warp to reconstruct the image, and penalize the error (a warp sketch appears after the code below).
3) Optical-flow L1 regularization loss: encourages the flow to be sparse. I suppose this mainly targets object tracking with a static background; otherwise the flow should not be sparse.

# abs then sum over the input, i.e. an L1 regularizer
class Loss4Laziness(nn.Module):
    def __init__(self, device='cpu', weight=1):
        super(Loss4Laziness, self).__init__()
        self.device = device
        self.weight = weight
        self.sparseMap = 0
        
    def forward(self, X):
        N,C,H,W = X.size()
        self.sparseMap  = torch.abs(X)
        totloss = torch.sum(torch.sum(self.sparseMap,3),2)/(H*W)        
        totloss = torch.sum(torch.sum(totloss,1))/(N*C)
        return totloss*self.weight
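
Loss 2) needs an image warped by a flow field. The repo implements this as funcOpticalFlowWarp; below is a minimal grid_sample-based sketch of the same mechanics (my assumption, not the repo's code):

import torch
import torch.nn.functional as F

def warp_by_flow(img, flow):
    # img: (n, c, h, w); flow: (n, 2, h, w) in pixels, channel 0 = u (x), 1 = v (y)
    n, _, h, w = img.shape
    yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    grid = torch.stack((xx, yy), dim=0).float().unsqueeze(0).to(img)  # (1, 2, h, w)
    pos = grid + flow
    # normalize to [-1, 1]; grid_sample expects (n, h, w, 2) ordered as (x, y)
    pos[:, 0] = 2.0 * pos[:, 0] / max(w - 1, 1) - 1.0
    pos[:, 1] = 2.0 * pos[:, 1] / max(h - 1, 1) - 1.0
    return F.grid_sample(img, pos.permute(0, 2, 3, 1), align_corners=True)

# the flow loss then reads e.g.: (warp_by_flow(imgA, flow) - imgB).abs().mean()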

4) Multiscale filter smoothness: essentially a penalty on the spatial gradients of the filters.

The gradient kernels, both horizontal and vertical: [1,-1], [1,0,-1], [1,0,0,-1], [1,0,0,0,-1]

hMap and vMap are the horizontal and vertical gradient maps of the filters:

loss=torch.sqrt((self.hMap)**2+self.epsilon**2) + torch.sqrt((self.vMap)**2+self.epsilon**2)

This smooths the filters over the spatial domain.
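
A one-scale sketch of this penalty (my reading of the formula above, using the [1,-1] kernel as plain differences):

import torch

def filter_smoothness_loss(filters, epsilon=0.001):
    # filters: (n, k*k, h, w); the repo repeats this for the longer kernels
    hMap = filters[:, :, :, 1:] - filters[:, :, :, :-1]  # horizontal [1, -1]
    vMap = filters[:, :, 1:, :] - filters[:, :, :-1, :]  # vertical [1, -1]
    return (torch.sqrt(hMap ** 2 + epsilon ** 2).mean()
            + torch.sqrt(vMap ** 2 + epsilon ** 2).mean())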

5) Bidirectional consistency loss:
The source code is below; it actually operates on the optical flow. I first wondered whether simply requiring the two flows to sum to zero would suffice. In general it would not: UV_BtoA lives at frame B's coordinates, so it has to be sampled at the positions reached by UV_AtoB, and pointwise addition only matches that for small, smooth motion. That is why the code composes the two warps on a coordinate grid instead.

# Bidirectional consistency loss: a grid warped through optical flow A->B and then B->A should recover the original grid
class Loss4BidirFlowVec(nn.Module):
    def __init__(self, device='cpu', weight=1):
        super(Loss4BidirFlowVec, self).__init__()
        self.device = device
        self.weight = weight
        self.epsilon = 0.001
    def forward(self, UV_AtoB, UV_BtoA):
        N, C, H, W = UV_AtoB.size()
        # mesh grid 
        xx = torch.arange(0, W).view(1,-1).repeat(H,1)
        yy = torch.arange(0, H).view(-1,1).repeat(1,W)
        xx = xx.view(1,1,H,W).repeat(N,1,1,1)
        yy = yy.view(1,1,H,W).repeat(N,1,1,1)
        GridXY = torch.cat((xx,yy),1).float()
        GridXY = GridXY.to(self.device)
        GridXY = Variable(GridXY)
        #GridXY[:,0,:,:] = 2.0*GridXY[:,0,:,:]/max(W-1,1)-1.0
        #GridXY[:,1,:,:] = 2.0*GridXY[:,1,:,:]/max(H-1,1)-1.0
        GridXY[:,0,:,:] = GridXY[:,0,:,:]/max(W-1,1)
        GridXY[:,1,:,:] = GridXY[:,1,:,:]/max(H-1,1)

        
        # warp the grid through A->B then B->A (composition, not pointwise addition; see the note above)
        mapXY_ABA = self.funcOpticalFlowWarp(GridXY, UV_AtoB.to(self.device), self.device)
        mapXY_ABA = self.funcOpticalFlowWarp(mapXY_ABA, UV_BtoA.to(self.device), self.device)
        self.DiffABA = GridXY-mapXY_ABA
        #self.DiffABA = torch.sqrt(torch.sum(self.DiffABA**2, 1))#.squeeze()
        self.DiffABA = torch.sqrt(torch.sum(self.DiffABA**2, 1)+self.epsilon**2)
        
        totloss_ABA = torch.sum(torch.sum(self.DiffABA,2),1)/(H*W)
        totloss_ABA = torch.sum(torch.sum(totloss_ABA))/N
        

        mapXY_BAB = self.funcOpticalFlowWarp(GridXY, UV_BtoA.to(self.device), self.device)
        mapXY_BAB = self.funcOpticalFlowWarp(mapXY_BAB, UV_AtoB.to(self.device), self.device)
        self.DiffBAB = GridXY-mapXY_BAB
        #self.DiffBAB = torch.sqrt(torch.sum(self.DiffBAB**2, 1))#.squeeze()        
        self.DiffBAB = torch.sqrt(torch.sum(self.DiffBAB**2, 1)+self.epsilon**2)
        
        totloss_BAB = torch.sum(torch.sum(self.DiffBAB,2),1)/(H*W)
        totloss_BAB = torch.sum(torch.sum(totloss_BAB))/N
        
        return (totloss_BAB+totloss_ABA)*self.weight

4. Multigrid inference

sampler = iter(dataloader4videoseg)
recErrorList = []
for i, sample in enumerate(sampler):
    ################## visualize some samples ###################
    imgListA2,imgListB2,imgListA4,imgListB4,imgListA8,imgListB8 = sample[:6]
    imgListA16,imgListB16,imgListA32,imgListB32 = sample[6:]
    
    imgListB2 = imgListA2[-1].unsqueeze(0)
    imgListA2 = imgListA2[0].unsqueeze(0)
    imgListB4 = imgListA4[-1].unsqueeze(0)
    imgListA4 = imgListA4[0].unsqueeze(0)
    imgListB8 = imgListA8[-1].unsqueeze(0)
    imgListA8 = imgListA8[0].unsqueeze(0)
    imgListB16 = imgListA16[-1].unsqueeze(0)
    imgListA16 = imgListA16[0].unsqueeze(0)
    imgListB32 = imgListA32[-1].unsqueeze(0)
    imgListA32 = imgListA32[0].unsqueeze(0)
    
        
        
    imgListA2 = imgListA2.to(device)
    imgListB2 = imgListB2.to(device)
    imgListA4 = imgListA4.to(device)
    imgListB4 = imgListB4.to(device)
    imgListA8 = imgListA8.to(device)
    imgListB8 = imgListB8.to(device)
    imgListA16 = imgListA16.to(device)
    imgListB16 = imgListB16.to(device)
    imgListA32 = imgListA32.to(device)
    imgListB32 = imgListB32.to(device)
    # The inference procedure is fairly involved: coarse-to-fine across the pyramid.
    if True:   
        warpImgWithScale16.device = device 
        warpImgWithScale8.device = device 
        warpImgWithScale4.device = device 
        warpImgWithScale2.device = device 
        warpImgWithScale1.device = device 
        
        # at 1/32 scale
        _, PFFx32_1to2 = curmodel(imgListA32, imgListB32)        
        recImgB32x2 = warpImgWithScale2(imgListA16, PFFx32_1to2)
        recImgB32x2 = recImgB32x2.detach()        
        recImgB32x4 = warpImgWithScale4(imgListA8, PFFx32_1to2)
        recImgB32x4 = recImgB32x4.detach()        
        recImgB32x8 = warpImgWithScale8(imgListA4, PFFx32_1to2)
        recImgB32x8 = recImgB32x8.detach() 
        recImgB32x16 = warpImgWithScale16(imgListA2, PFFx32_1to2)
        recImgB32x16 = recImgB32x16.detach()
        F1 = warpImgWithScale16.UVgrid.detach() # coarsest flow, upscaled to the full image size
        
        # at 1/16 scale
        _, PFFx16_1to2 = curmodel(recImgB32x2, imgListB16)
        recImgB16x2 = warpImgWithScale2(recImgB32x4, PFFx16_1to2)
        recImgB16x2 = recImgB16x2.detach()
        recImgB16x4 = warpImgWithScale4(recImgB32x8, PFFx16_1to2)
        recImgB16x4 = recImgB16x4.detach()
        recImgB16x8 = warpImgWithScale8(recImgB32x16, PFFx16_1to2)
        recImgB16x8 = recImgB16x8.detach()
        F2 = warpImgWithScale8.UVgrid.detach()  
                
        _, PFFx8_1to2 = curmodel(recImgB16x2, imgListB8) 
        recImgB8x2 = warpImgWithScale2(recImgB16x4, PFFx8_1to2)
        recImgB8x2 = recImgB8x2.detach()
        recImgB8x4 = warpImgWithScale4(recImgB16x8, PFFx8_1to2)
        recImgB8x4 = recImgB8x4.detach()
        F3 = warpImgWithScale4.UVgrid.detach()  
                
        _, PFFx4_1to2 = curmodel(recImgB8x2, imgListB4) 
        recImgB4x2 = warpImgWithScale2(recImgB8x4, PFFx4_1to2)
        recImgB4x2 = recImgB4x2.detach()
        F4 = warpImgWithScale2.UVgrid.detach() 
                
        _, PFFx2_1to2 = curmodel(recImgB4x2, imgListB2) 
        recImgB2x1 = warpImgWithScale1(recImgB4x2, PFFx2_1to2)
        recImgB2x1 = recImgB2x1.detach()
        F5 = warpImgWithScale1.UVgrid.detach()
        recImg = recImgB2x1
        
        
        warpImgWithScale16.device = supplDevice 
        warpImgWithScale8.device = supplDevice 
        warpImgWithScale4.device = supplDevice 
        warpImgWithScale2.device = supplDevice 
        warpImgWithScale1.device = supplDevice 
        
    F1tmp = F1[0].detach().cpu()
    F2tmp = F2[0].detach().cpu()
    F3tmp = F3[0].detach().cpu()
    F4tmp = F4[0].detach().cpu()
    F5tmp = F5[0].detach().cpu()
    F_fine2coarse = [F5tmp, F4tmp, F3tmp, F2tmp, F1tmp]
    OF4vis = genFlowVector4Visualization(F_fine2coarse)
    recImg,_ = funcOpticalFlowWarp(imgListA2.to('cpu'), OF4vis)
    
    figWinNumHeight, figWinNumWidth = 1, 5
    plt.figure(figsize=(22, 5), dpi=64, facecolor='w', edgecolor='k') # figsize -- inch-by-inch
    subwinCount = 1 
    
    plt.subplot(figWinNumHeight,figWinNumWidth,subwinCount)
    subwinCount += 1
    imgA = imgListA2[0].detach().cpu().numpy().squeeze().transpose((1,2,0)) 
    imgA = (imgA+1)/2
    imgA = imgA.clip(0,1)
    plt.imshow(imgA), plt.axis('off'), plt.title(' frameA')
    
    plt.subplot(figWinNumHeight,figWinNumWidth,subwinCount)
    subwinCount += 1
    imgB = imgListB2[0].detach().cpu().numpy().squeeze().transpose((1,2,0)) 
    imgB = (imgB+1)/2
    imgB = imgB.clip(0,1)
    plt.imshow(imgB), plt.axis('off'), plt.title('frameB')
    
    plt.subplot(figWinNumHeight,figWinNumWidth,subwinCount)
    subwinCount += 1
    img = recImg[0].detach().cpu().numpy().squeeze().transpose((1,2,0)) 
    img = (img+1)/2
    img = img.clip(0,1)
    plt.imshow(img), plt.axis('off'), plt.title('recFrameB')
    
    plt.subplot(figWinNumHeight,figWinNumWidth,subwinCount)
    subwinCount += 1
    UV = OF4vis.detach().cpu().numpy()
    UV = UV/np.abs(UV).max()
    flowVisShow = objDemoShowFlow.computeColor(UV[0], UV[1])/255.    
    plt.imshow(flowVisShow)
    plt.axis('off')
    plt.title('flow A->B')         
    
    
    A = np.abs(imgB*255-img*255)
    H,W,C = A.shape
    A = np.reshape(A, (H*W*C, -1))
    recErrorList+=[np.mean(A)]
    
    
    if saveFigures:
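        # note: scipy.misc.imsave is removed from recent SciPy releases; imageio.imwrite works as a replacement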
        scipy.misc.imsave(os.path.join(save_dir,format(i,'05d')+'_A.png'), imgA*255)
        scipy.misc.imsave(os.path.join(save_dir,format(i,'05d')+'_B.png'), imgB*255)
        scipy.misc.imsave(os.path.join(save_dir,format(i,'05d')+'_recB.png'), img*255)
        scipy.misc.imsave(os.path.join(save_dir,format(i,'05d')+'_flowMap.png'), flowVisShow*255)

Results:
[Figure: frame A, frame B, reconstructed frame B, and flow A->B]
