Building Spatial Transformer Networks

import torch
import torch.nn as nn
import torch.nn.functional as F

2D Image Warping (Warp)
The warp operation works as follows. During initialization, a mesh grid of the same size as the image is generated, i.e. the regular grid; if the image is warped with this regular grid alone, the output is identical to the original image and no deformation occurs (a visualization of this process can be found in another of my blog posts). When the warp function is called, two arguments are required: the original image and a flow field. The flow field is added to the regular grid, and the result is used to resample the original image. The implementation is adapted from the open-source code of the PWC-Net paper, whose warp function is modified as follows:

class Warper2d(nn.Module):
    """
    Warp an image/tensor (im2) back to im1 according to the optical flow.
    img_src: [B, 1, H1, W1] (source image used for prediction, size 32)
    img_smp: [B, 1, H2, W2] (image for sampling, size 44)
    flow:    [B, 2, H1, W1] flow predicted from the source image pair
    """
    def __init__(self, img_size):
        super(Warper2d, self).__init__()
        self.img_size = img_size
        H, W = img_size, img_size
        # Regular (identity) mesh grid, shape [2, H, W]
        xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
        xx = xx.view(1, H, W)
        yy = yy.view(1, H, W)
        self.grid = torch.cat((xx, yy), 0).float()  # [2, H, W]

    def forward(self, flow, img):
        grid = self.grid.repeat(flow.shape[0], 1, 1, 1)  # [B, 2, H, W]
        if img.is_cuda:
            grid = grid.cuda()
        # If flow and img spatial sizes differ (e.g. 32 vs 44), the flow could be
        # padded first, e.g. flow = F.pad(flow, [pad]*4, 'replicate').
        # Add the predicted flow to the regular grid to get the sampling coordinates.
        vgrid = grid + flow

        # Scale coordinates to [-1, 1], the range expected by grid_sample
        vgrid = 2.0 * vgrid / (self.img_size - 1) - 1.0

        vgrid = vgrid.permute(0, 2, 3, 1)  # [B, H, W, 2]
        output = F.grid_sample(img, vgrid, align_corners=True)
        # A validity mask can optionally be obtained by warping an all-ones tensor
        # with the same grid, thresholding it, and multiplying the output by it.
        return output
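
As a quick sanity check (a minimal sketch not from the original post; the batch size, image size, and variable names below are assumed for illustration), warping with an all-zero flow field should reproduce the input, since the sampling grid then coincides with the regular grid:

# Hypothetical shapes chosen for illustration only.
warper = Warper2d(img_size=32)
img = torch.rand(4, 1, 32, 32)     # [B, 1, H, W]
flow = torch.zeros(4, 2, 32, 32)   # [B, 2, H, W], zero displacement
out = warper(flow, img)            # note the (flow, img) argument order
print(torch.allclose(out, img, atol=1e-5))  # expected: True

The normalization to [-1, 1] matches what grid_sample expects; because the denominator is (size - 1), passing align_corners=True keeps the coordinate mapping consistent, which is why a zero flow returns the original image.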

3D Image Warping (Warp)
Following the idea of the 2D warp, the operation can be extended to 3D image registration. The implementation is as follows:

class Warper3d(nn.Module):
    """
    Warp a volume according to the (3D) optical flow.
    image: [B, 1, D, H, W] volume for sampling
    flow:  [B, 3, D, H, W] flow predicted from the source image pair
    """
    def __init__(self, img_size):
        super(Warper3d, self).__init__()
        self.img_size = img_size
        D, H, W = img_size
        # Regular (identity) mesh grid, shape [3, D, H, W]
        xx = torch.arange(0, W).view(1, 1, -1).repeat(D, H, 1).view(1, D, H, W)
        yy = torch.arange(0, H).view(1, -1, 1).repeat(D, 1, W).view(1, D, H, W)
        zz = torch.arange(0, D).view(-1, 1, 1).repeat(1, H, W).view(1, D, H, W)
        self.grid = torch.cat((xx, yy, zz), 0).float()  # [3, D, H, W]

    def forward(self, img, flow):
        # Note: the argument order here is (img, flow), unlike Warper2d's (flow, img).
        grid = self.grid.repeat(flow.shape[0], 1, 1, 1, 1)  # [B, 3, D, H, W]
        if img.is_cuda:
            grid = grid.cuda()
        # Add the predicted flow to the regular grid to get the sampling coordinates.
        vgrid = grid + flow

        # Scale coordinates to [-1, 1] along each axis
        D, H, W = self.img_size
        vgrid[:, 0] = 2.0 * vgrid[:, 0] / (W - 1) - 1.0
        vgrid[:, 1] = 2.0 * vgrid[:, 1] / (H - 1) - 1.0
        vgrid[:, 2] = 2.0 * vgrid[:, 2] / (D - 1) - 1.0

        vgrid = vgrid.permute(0, 2, 3, 4, 1)  # [B, D, H, W, 3]
        output = F.grid_sample(img, vgrid, padding_mode='border', align_corners=True)

        return output
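
The same identity check carries over to the 3D module (again a minimal sketch with assumed shapes and names; in 3D the flow has three channels for the x, y, and z displacements):

# Hypothetical shapes chosen for illustration only.
warper3d = Warper3d(img_size=(16, 32, 32))   # (D, H, W)
vol = torch.rand(2, 1, 16, 32, 32)           # [B, 1, D, H, W]
flow3d = torch.zeros(2, 3, 16, 32, 32)       # [B, 3, D, H, W], zero displacement
out3d = warper3d(vol, flow3d)                # note the (img, flow) argument order
print(torch.allclose(out3d, vol, atol=1e-5))  # expected: True

Here padding_mode='border' clamps samples that fall outside the volume to the nearest edge voxel instead of filling them with zeros.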