Leveling up with PyTorch: a video completion algorithm (Onion Peel Network)


Preface

This is the Onion Peel Network algorithm, which I found on GitHub. The author only released the demo part, so I tried to implement the training part myself, and so far the results at least show some signs of actual completion. It is still far from mature, but I am posting it for everyone to look at anyway. As usual, I kept the code strictly in the same style as my earlier posts, to make it easier to study and to optimize later.

1. config.py

import argparse

parser=argparse.ArgumentParser(description="Onion Peel Network")

parser.add_argument('--project_name', type=str, default="video completion by Onion Peel Network",
                    help='project name')
parser.add_argument("--use_cuda", type=bool, default=True,
                    help="whether to use CUDA")
parser.add_argument("--seed", type=int, default=123,
                    help="random seed")
parser.add_argument("--resume", type=bool, default=True,
                    help="whether to initialize the model from pretrained weights")
parser.add_argument("--pretrained_weight", type=str, default='OPN.pth',
                    help="path to the pretrained weights")
parser.add_argument("--lr", type=float, default=0.0001,
                    help="learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-4,
                    help="weight decay")
parser.add_argument("--momentum", type=float, default=0.5,
                    help="momentum")
parser.add_argument("--epoch", type=int, default=10,
                    help="number of training epochs")
parser.add_argument("--train_batch_size", type=int, default=1,
                    help="training batch size")
parser.add_argument("--test_batch_size", type=int, default=1,
                    help="test batch size")
parser.add_argument("--save", type=bool, default=True,
                    help="save result images")

This is the project's configuration file, the same setup as in my previous posts.
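
As a quick sanity check, the parser can be exercised on its own (a minimal sketch; the printed fields are just examples):

from config import parser

# parse with defaults only; note that argparse's type=bool treats any non-empty
# string as True, so these switches are effectively controlled by their defaults
args = parser.parse_args([])
print(args.project_name, args.lr, args.epoch)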

2. datalist.py

import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

T, H, W = 5, 240, 424


class Dataset(Dataset):
    def __init__(self, type='train'):
        self.type = type

    def __len__(self):
        return len(os.listdir('Image_inputs/W')) // 2

    def __getitem__(self, index):
        if index >= len(self) - 5:
            index = index - 5

        print(index)

        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)

        for i in range(5):
            # rgb
            img_file = os.path.join('Image_inputs', 'W', '{:04d}.jpg'.format(index + i))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            # mask
            mask_file = os.path.join('Image_inputs', 'W', '{:04d}.png'.format(index + i))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))  # cv2.dilate: slightly dilate the mask
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # dists
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2,
                                                      maskSize=5)  # distance of each hole pixel to the nearest non-hole pixel
        # convert the arrays to tensors with layout (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()

        # remove holes: cut the corresponding hole regions out of the frames and fill them
        # with the ImageNet channel means [0.485, 0.456, 0.406]
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid area: the pixels outside the holes, i.e. the original values that are still
        # available; used later when computing the loss
        valids = 1 - holes

        # frames = frames.unsqueeze(0)
        # holes = holes.unsqueeze(0)
        # dists = dists.unsqueeze(0)
        # valids = valids.unsqueeze(0)

        return frames, valids, dists


class Generator(object):
    def __init__(self, batch_size=1):
        self.batch_size = batch_size
        self.images = os.listdir("Image_inputs")

    def generator(self):
        # while True:
        dir = self.images[random.choice(range(len(self.images)))]

        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        label = np.empty((T, H, W, 3), dtype=np.float32)

        for i in range(5):
            # rgb
            img_file = os.path.join('Image_inputs', dir, 'gt_{:1d}.jpg'.format(i))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            label[i] = raw_frame
            # mask
            mask_file = os.path.join('Image_inputs', dir, 'mask_{:1d}.png'.format(i))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))  # cv2.dilate: slightly dilate the mask
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # dists
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2,
                                                      maskSize=5)  # distance of each hole pixel to the nearest non-hole pixel
        # convert the arrays to tensors with layout (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        label = torch.from_numpy(np.transpose(label, (3, 0, 1, 2)).copy()).float()

        # remove holes: cut the corresponding hole regions out of the frames and fill them
        # with the ImageNet channel means [0.485, 0.456, 0.406]
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid area: the pixels outside the holes, i.e. the original values that are still
        # available; used later when computing the loss
        valids = 1 - holes

        frames = frames.unsqueeze(0)
        dists = dists.unsqueeze(0)
        valids = valids.unsqueeze(0)
        label = label.unsqueeze(0)

        yield frames, valids, dists, label


class Dataset2(Dataset):
    def __init__(self):
        self.images = os.listdir("image")

    def __len__(self):
        return 1

    def __getitem__(self, index):
        dir = self.images[random.choice(range(len(self.images)))]

        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        label = np.empty((T, H, W, 3), dtype=np.float32)

        for i in range(5):
            # rgb
            img_file = os.path.join('image', dir, 'gt_{:04d}.jpg'.format(random.choice(range(len(os.listdir(os.path.join('image', dir)))))))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            label[i] = raw_frame
            # mask
            mask_file = os.path.join('mask', 'mask_{:04d}.png'.format(random.choice(range(0,52))))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))  # cv2.dilate: slightly dilate the mask
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # dists
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2,
                                                      maskSize=5)  # distance of each hole pixel to the nearest non-hole pixel
        # convert the arrays to tensors with layout (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        label = torch.from_numpy(np.transpose(label, (3, 0, 1, 2)).copy()).float()

        # remove holes: cut the corresponding hole regions out of the frames and fill them
        # with the ImageNet channel means [0.485, 0.456, 0.406]
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid area: the pixels outside the holes, i.e. the original values that are still
        # available; used later when computing the loss
        valids = 1 - holes

        return frames, valids, dists, label

This is the data-loading part; since the requirements here are different, I made a few small changes. The generator-style loading (Dataset2) is what the training script actually uses; the classes above it are my earlier, failed attempts.
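
As a sanity check, here is a minimal sketch of how the loader is consumed and which tensor shapes come out of Dataset2 (it assumes the directory layout image/<sequence>/gt_XXXX.jpg and mask/mask_XXXX.png that Dataset2 expects actually exists):

from torch.utils.data import DataLoader

from datalist import Dataset2

loader = DataLoader(Dataset2(), batch_size=1, shuffle=True)
frames, valids, dists, label = next(iter(loader))
# frames / label: (B, 3, T, H, W); valids / dists: (B, 1, T, H, W) with T=5, H=240, W=424
print(frames.shape, valids.shape, dists.shape, label.shape)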

3. common.py

import torch
import torch.nn as nn


def get_features(img, model, layers=None):
    '''Collect intermediate feature maps from the given (VGG-style) model.'''
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content layer
            '28': 'conv5_1'
        }

    features = {}
    x = img
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features


def gram_matrix(tensor):
    '''Compute the Gram matrix.'''
    _, d, h, w = tensor.size()  # the first dimension is the batch size

    tensor = tensor.view(d, h * w)

    gram = torch.mm(tensor, tensor.t())

    return gram


'''
TV loss is a commonly used regularization term (note: a regularizer, used together with
other losses to constrain noise).
Differences between neighbouring pixel values can be reduced to some extent by lowering
the TV loss.

Even a small amount of noise in an image can have a large effect on the restored result,
because many restoration algorithms amplify noise. That is why regularization terms are
added to the optimization problem to keep the image smooth.
'''

class TVLoss(nn.Module):
    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:, :, 1:, :])  # number of difference terms summed
        count_w = self._tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        # x[:,:,1:,:] - x[:,:,:h_x-1,:] shifts the image against itself: it splits the image
        # into two copies offset by one pixel, one starting at row 1 and running to the last
        # row, the other starting at row 0 and running to the second-to-last row. Their
        # difference is, for every pixel, the difference to the neighbouring pixel below it.
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]


class L1_Loss(nn.Module):
    def __init__(self):
        super(L1_Loss, self).__init__()

    def forward(self, x, y, pv):
        loss = 0
        if pv.ndim > 4:
            pv = pv.squeeze(dim=0)
            for i in range(pv.size(1)):
                loss += torch.sum(torch.abs(x - y) * pv[:, i, :, :])
            return loss
        else:
            # masked L1, normalized by the first three tensor dimensions
            loss = torch.sum(torch.abs(x - y) * pv) / (y.shape[0] * y.shape[1] * y.shape[2])
            return loss

class L1_Lossv2(nn.Module):
    def __init__(self):
        super(L1_Lossv2, self).__init__()

    def forward(self, x, y, pv):
        loss = 0
        if pv.ndim > 4:
            for i in range(pv.size()[1]):
                temp = pv[:, :, i]
                loss = loss + torch.sum(torch.abs(x.flatten() - y.flatten()) * temp.flatten())
            return loss
        else:
            loss = torch.sum(torch.abs(x.flatten() - y.flatten()) * pv.flatten())
            return loss

def L1(x, y, mask):
    res = torch.abs(x - y)
    res = res * mask
    return torch.sum(res) / (y.shape[0] * y.shape[1] * y.shape[2])


def ll1(x, y):
    return torch.sum(x - y)



This file is mostly me writing the losses myself. The paper uses quite a few of them and there are few blog posts that explain them, so I spent a lot of time testing, but being able to write your own losses really matters. Top-conference papers keep adding more and more tricks these days, and it is no longer just network architecture design that I cannot follow. Honestly I do not feel like keeping up; would it be so bad if a plain CNN could still conquer everything?
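
To make the loss terms concrete, here is a minimal sketch that runs TVLoss and the masked L1 variants on random tensors (the shapes are just examples):

import torch

from common import L1, L1_Loss, TVLoss

x = torch.rand(1, 3, 240, 424)                      # predicted frame
y = torch.rand(1, 3, 240, 424)                      # ground-truth frame
mask = (torch.rand(1, 1, 240, 424) > 0.5).float()   # region the loss should cover

tv = TVLoss()(x)                 # smoothness regularizer on the prediction
l1 = L1(x, y, mask)              # L1 restricted to the masked region
l1_mod = L1_Loss()(x, y, mask)   # same idea, wrapped as an nn.Module
print(tv.item(), l1.item(), l1_mod.item())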

4. model.py

from __future__ import division

# general libs
import math
import sys

sys.path.insert(0, '.')
from .common import *

sys.path.insert(0, '../utils/')
from utils.helpers import *


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv12 = GatedConv2d(5, 64, kernel_size=5, stride=2, padding=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv2 = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                                 activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv23 = GatedConv2d(64, 128, kernel_size=3, stride=2, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3a = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3b = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3c = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=4, dilation=4,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3d = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=8, dilation=8,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.key3 = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)  # 4
        self.val3 = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)  # 4

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f, in_v, in_h):
        # normalize the input frames
        f = (in_f - self.mean) / self.std
        x = torch.cat([f, in_v, in_h], dim=1)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3a(x)
        x = self.conv3b(x)
        x = self.conv3c(x)
        x = self.conv3d(x)
        k = self.key3(x)
        v = self.val3(x)
        return k, v


# asymmetric attention block (masked read)
class MaskedRead(nn.Module):
    def __init__(self):
        super(MaskedRead, self).__init__()

    def forward(self, qkey, qval, qmask, mkey, mval, mmask):
        '''
        read for *mask area* of query from *mask area* of memory
        '''

        B, Dk, _, H, W = mkey.size()
        _, Dv, _, _, _ = mval.size()
        # key: b,dk,t,h,w
        # value: b,dv,t,h,w
        # mask: b,1,t,h,w
        for b in range(B):
            # exceptions
            if qmask[b, 0].sum() == 0 or mmask[b, 0].sum() == 0:
                # print('skipping read', qmask[b,0].sum(), mmask[b,0].sum())
                # no query or mask pixels -> skip read
                continue
            # [128,284]
            qk_b = qkey[b, :, qmask[b, 0]]  # dk, Nq
            mv_b = mval[b, :, mmask[b, 0]]  # dv, Nm
            mk_b = mkey[b, :, mmask[b, 0]]  # dk, Nm   #mkey(1,128,4,60,106)  mmask(1,1,5,60,106)

            # print(mv_b.shape)

            p = torch.mm(torch.transpose(mk_b, 0, 1), qk_b)  # Nm, Nq
            p = p / math.sqrt(Dk)  # scale by sqrt(d_k), as in scaled dot-product attention
            p = torch.softmax(p, dim=0)

            read = torch.mm(mv_b, p)  # dv, Nq
            # qval[b,:,qmask[b,0]] = read # dv, Nq
            qval[b, :, qmask[b, 0]] = qval[b, :, qmask[b, 0]] + read  # dv, Nq

        return qval


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv3d = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=8, dilation=8,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3c = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=4, dilation=4,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3b = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3a = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv32 = GatedConv2d(128, 64, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv2 = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                                 activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv21 = GatedConv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)  # 1


        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        x = self.conv3d(x)
        x = self.conv3c(x)
        x = self.conv3b(x)
        x = self.conv3a(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
        x = self.conv32(x)
        x = self.conv2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
        x = self.conv21(x)
        p = (x * self.std) + self.mean
        return p


class OPN(nn.Module):
    def __init__(self, mode='Train', CPU_memory=False, thickness=8):
        super(OPN, self).__init__()
        self.Encoder = Encoder()
        self.MaskedRead = MaskedRead()
        self.Decoder = Decoder()

        self.thickness = thickness

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1, 1))

    def memorize(self, frames, valids, dists):
        '''
        encode every frame of *valid* area into key:value
        Done once as initialization
        '''

        # pad so that height and width are divisible by 4
        (frames, valids, dists), pad = pad_divide_by([frames, valids, dists], 4, (frames.size()[3], frames.size()[4]))

        # make hole
        holes = (dists > 0).float()
        frames = (1 - holes) * frames + holes * self.mean3d
        batch_size, _, num_frames, height, width = frames.size()
        # num_frames: number of frames (5 here)
        # encoding...
        key_ = []
        val_ = []
        for t in range(num_frames):
            key, val = self.Encoder(frames[:, :, t], valids[:, :, t], holes[:, :, t])
            key_.append(key)
            val_.append(val)

        keys = torch.stack(key_, dim=2)
        vals = torch.stack(val_, dim=2)

        hols = (F_upsample3d(holes, size=(int(height / 4), int(width / 4)), mode='bilinear', align_corners=False) > 0)
        return keys, vals, hols

    def read(self, mkey, mval, mhol, frame, valid, dist):
        ''' 
        ## assume single frame query
        1) encode current status of frames -> query
        2) read from memmories (computed calling 'memorize')
        3) decode the read feature
        4) compute loss on peel area
        '''
        thickness = self.thickness

        # padding
        (frame, valid, dist), pad = pad_divide_by([frame, valid, dist], 4, (frame.size()[2], frame.size()[3]))
        batch_size, _, height, width = frame.size()
        # make hole and peel..
        hole = (dist > 0).float()
        peel = hole * (dist <= thickness).float()
        # update dist: shrink the remaining hole by one peel thickness
        next_dist = torch.clamp(dist - thickness, 0, 9999)
        # get 1/4 scale mask
        peel3 = (F.interpolate(peel, size=(int(height / 4), int(width / 4)), mode='bilinear', align_corners=False) >= 0.5)
        # refresh the frame: fill the hole area with the mean color
        frame = (1 - hole) * frame + hole * self.mean

        # reading and decoding...
        qkey, qval = self.Encoder(frame, valid, hole)
        qpel = peel3
        # read via the asymmetric attention block
        read = self.MaskedRead(qkey, qval, qpel, mkey, mval, ~mhol)
        # decode
        pred = self.Decoder(read)
        comp = (1 - peel) * frame + peel * pred  # fill peel area

        if pad[2] + pad[3] > 0:
            comp = comp[:, :, pad[2]:-pad[3], :]
            next_dist = next_dist[:, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            comp = comp[:, :, :, pad[0]:-pad[1]]
            next_dist = next_dist[:, :, :, pad[0]:-pad[1]]
        # clamp to keep the color channels in the valid range
        comp = torch.clamp(comp, 0, 1)
        return comp, next_dist, peel

    def forward(self, *args, **kwargs):
        # print(len(args), len(kwargs))
        if len(args) == 3:
            return self.memorize(*args)
        else:
            return self.read(*args, **kwargs)

This is the part taken from the paper. Honestly the technique is impressive; an ordinary person could not write this from scratch, so I will not pretend it is anything other than copied.
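
The heart of the model is the masked read: query features inside the peel area attend over memory features outside the holes of the reference frames. Below is a small self-contained sketch of that computation on random tensors, mirroring the inner loop of MaskedRead (the dimensions are illustrative, not the real feature maps):

import math

import torch

Dk, Dv = 128, 128
Nq, Nm = 300, 1200                  # number of query (peel) and memory (valid) positions

qkey = torch.rand(Dk, Nq)           # keys of the masked query positions
mkey = torch.rand(Dk, Nm)           # keys of the valid memory positions
mval = torch.rand(Dv, Nm)           # values of the valid memory positions

p = torch.mm(mkey.t(), qkey) / math.sqrt(Dk)  # (Nm, Nq) scaled dot-product similarity
p = torch.softmax(p, dim=0)                   # normalize over the memory positions
read = torch.mm(mval, p)                      # (Dv, Nq) value gathered for each query position
print(read.shape)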

5. model_common.py

from __future__ import division

# general libs
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.insert(0, '../utils/')
from utils.helpers import *


##########################################
############   Generic   #################
##########################################

def pad_divide_by(in_list, d, in_size):
    out_list = []
    h, w = in_size
    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    for inp in in_list:
        out_list.append(F.pad(inp, pad_array))
    return out_list, pad_array


class ConvGRU(nn.Module):
    def __init__(self, mdim, kernel_size=3, padding=1):
        super(ConvGRU, self).__init__()
        self.convIH = nn.Conv2d(mdim, 3 * mdim, kernel_size=kernel_size, padding=padding)
        self.convHH = nn.Conv2d(mdim, 3 * mdim, kernel_size=kernel_size, padding=padding)

    def forward(self, input, hidden_tm1):
        if hidden_tm1 is None:
            hidden_tm1 = torch.zeros_like(input)
        gi = self.convIH(input)
        gh = self.convHH(hidden_tm1)
        i_r, i_i, i_n = torch.chunk(gi, 3, dim=1)
        h_r, h_i, h_n = torch.chunk(gh, 3, dim=1)
        resetgate = torch.sigmoid(i_r + h_r)  # reset
        inputgate = torch.sigmoid(i_i + h_i)  # update
        newgate = torch.tanh(i_n + resetgate * h_n)
        # hidden_t = inputgate * hidden_tm1 + (1-inputgate)*newgate
        hidden_t = newgate + inputgate * (hidden_tm1 - newgate)
        return hidden_t


def F_upsample3d(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    num_frames = x.size()[2]
    up_s = []
    for f in range(num_frames):
        up = F.interpolate(x[:, :, f], size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
        up_s.append(up)
    ups = torch.stack(up_s, dim=2)
    return ups


def F_upsample(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    if x.dim() == 5:  # 3d
        return F_upsample3d(x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
    else:
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)


class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, activation=None):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.gating_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                     stride, padding, dilation, groups, bias)
        init_He(self)
        self.activation = activation

    def forward(self, input):
        # O = act(Feature) * sig(Gating)
        feature = self.input_conv(input)
        if self.activation:
            feature = self.activation(feature)
        gating = torch.sigmoid(self.gating_conv(input))
        return feature * gating

These are the common layers and helper functions used by the model, also copied.
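
GatedConv2d is the layer everything above is built from: one convolution produces features, a second produces a soft gate, and the output is feature * sigmoid(gate). Here is a minimal sketch of the idea on a dummy input (TinyGatedConv2d is my own illustrative re-implementation, not the repo's class, and the shapes are only examples):

import torch
import torch.nn as nn


class TinyGatedConv2d(nn.Module):
    """Same idea as GatedConv2d above, without the repo's init_He initialization."""

    def __init__(self, cin, cout):
        super().__init__()
        self.feat = nn.Conv2d(cin, cout, 3, padding=1)
        self.gate = nn.Conv2d(cin, cout, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.feat(x)) * torch.sigmoid(self.gate(x))


x = torch.rand(1, 5, 64, 64)   # e.g. RGB + valid mask + hole mask, as fed to the Encoder
y = TinyGatedConv2d(5, 64)(x)
print(y.shape)                 # torch.Size([1, 64, 64, 64])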

6. train.py

import os
import sys

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import models
from tqdm import tqdm

from common import *
from datalist import Dataset2
from models.OPN import OPN
from utils.helpers import *

sys.path.append('utils/')
sys.path.append('models/')

style_weights = {
    'conv1_1': 1,
    'conv2_1': 0.8,
    'conv3_1': 0.5,
    'conv4_1': 0.3,
    'conv5_1': 0.1,
}
from config import parser


class train(object):
    def __init__(self):
        self.args = parser.parse_args()
        print(f"-----------{self.args.project_name}-----------")
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(self.args.seed)
        else:
            torch.manual_seed(self.args.seed)

        self.device = torch.device("cuda" if use_cuda else "cpu")

        kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

        '''
        Build the DataLoaders
        '''
        # TODO: the dataset still needs to be prepared properly
        print("Create Dataloader")
        self.train_loader = DataLoader(Dataset2(), batch_size=1, shuffle=True, **kwargs)
        self.test_loader = DataLoader(Dataset2(), batch_size=1, shuffle=True, **kwargs)
        '''
        Define the model
        '''
        print("Create Model")
        self.model = OPN().to(self.device)
#        self.model = nn.DataParallel(OPN())
        if use_cuda:
            # self.model = self.model.cuda()
            cudnn.benchmark = True
        '''
        Optionally load pretrained weights
        '''

        # pretrained VGG16 features, used for the perceptual (content/style) losses
        self.vgg = models.vgg16(pretrained=True).to(self.device).features

        for i in self.vgg.parameters():
            i.requires_grad = False
        try:
            if self.args.resume and self.args.pretrained_weight:
                self.model.load_state_dict(torch.load(self.args.pretrained_weight), strict=False)
                print("Pretrained weights loaded")
        except Exception:
            print("Failed to load pretrained weights")
        '''
        CUDA speed-ups
        '''
        if use_cuda:
         #   self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True

        '''
        Build the loss functions,
        choose the optimizer,
        and optionally a learning-rate schedule
        '''
        print("Establish the loss, optimizer and learning_rate function")
        self.loss_tv = TVLoss()
        self.loss_l1 = L1_Loss()
        # the style loss and content loss are computed inline in train() below
        # self.optimizer = optim.SGD(
        #     params=self.model.parameters(),
        #     lr=self.args.lr,
        #     weight_decay=self.args.weight_decay,
        #     momentum=0.5
        # )
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,  # avoid division by zero
            weight_decay=0
        )
        # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=5, eta_min=1e-5)
        '''
        Start training
        '''
        print("Start training")
        for epoch in tqdm(range(1, self.args.epoch + 1)):
            self.train(epoch)
            # note: with the default of 10 epochs this test() call never triggers
            if epoch % 20 == 0:
                self.test(epoch)

        torch.cuda.empty_cache()

        print("finish model training")

    def train(self, epoch):
        self.model.train()
        for data in self.train_loader:
            
            self.content_loss = 0
            self.style_loss = 0

            midx = list(range(0, 5))
            # frames: the corrupted images, valids: the pixel region that is available, dists: the pixel region to be filled
            frames, valids, dists, label = data
            frames, valids, dists, label = frames.to(self.device), valids.to(self.device), dists.to(
                self.device), label.to(self.device)
            # every frame has been encoded; key and val have shape (1,128,5,60,106), hol has shape (1,1,5,60,106)
            mkey, mval, mhol = self.model(frames[:, :, midx], valids[:, :, midx], dists[:, :, midx])

           
            allloss = 0
            for f in range(5):
                loss=0
                # for each frame, use the other 4 frames as references
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = mkey[:, :, ridx], mval[:, :, ridx], mhol[:, :, ridx]
                # image completion, peel by peel
                for r in range(5):
                    if r == 0:
                        # start from the frame being repaired
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    # comp is the damaged image, completed layer by layer
                    # valids is the region with no missing information
                    # dist is the region with missing information
                    '''
                    Guided by dist, repair the image iteratively, 8 pixels of distance at a
                    time; valids is the 0/1 mask of the non-hole region. comp is built on top
                    of frame, so the two are almost identical and only the peel area enters
                    the loss.
                    '''
                    comp, dist, peel = self.model(fkey, fval, fhol, comp, valids[:, :, f], dist)
                    # each iteration minimizes the L1 distance to the GT in pixel space and in deep feature space
                    loss += 100 * L1(comp, label[:, :, f], peel)
                    # loss += L1(comp, label[:, :, f], valids[:, :, f])
                    loss += 0.2 * self.loss_l1(comp, label[:, :, f], valids[:, :, midx])
                    # loss += 100 * ll1(comp, frames[:, :, f])


                # content loss
                content_features = get_features(frames[:, :, f], self.vgg)
                target_features = get_features(comp, self.vgg)
                self.content_loss = torch.mean(
                    torch.abs((target_features['conv4_2'] - content_features['conv4_2'])))
                loss = loss + 0.05 * self.content_loss
                # style loss (the style reference is the ground-truth frame)
                style_features = get_features(label[:, :, f], self.vgg)
                style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
                '''add the Gram-matrix loss of every layer'''
                for layer in style_weights:
                    target_feature = target_features[layer]
                    target_gram = gram_matrix(target_feature)
                    _, d, h, w = target_feature.shape
                    style_gram = style_grams[layer]
                    layer_style_loss = style_weights[layer] * torch.mean(torch.abs((target_gram - style_gram)))
                    self.style_loss += layer_style_loss / (d * h * w)  # accumulate, normalized by the feature-map size
                loss = loss + 120 * self.style_loss
                # tv loss
                loss += 0.01 * self.loss_tv(comp)
                allloss += loss
            self.optimizer.zero_grad()
            allloss.backward()
            self.optimizer.step()
        # self.scheduler.step()
       # print("epoch{}".format(epoch) + "  loss:{}".format(loss.cpu()))

    def test(self, epoch):
        self.model.eval()
        for frames, valids, dists, _ in self.test_loader:
            midx = list(range(0, 5))
            # frames, valids, dists = data
            frames, valids, dists = frames.to(self.device), valids.to(self.device), dists.to(self.device)
            with torch.no_grad():
                # first encode all 5 frames
                mkey, mval, mhol = self.model(frames[:, :, midx], valids[:, :, midx], dists[:, :, midx])
            # for each frame, use the other 4 frames as references
            for f in range(5):
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = mkey[:, :, ridx], mval[:, :, ridx], mhol[:, :, ridx]
                # image completion
                for r in range(999):
                    if r == 0:
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    with torch.no_grad():
                        comp, dist, peel = self.model(fkey, fval, fhol, comp, valids[:, :, f], dist)

                    comp, dist = comp.detach(), dist.detach()
                    # once the hole is completely filled, save the image and move on to the next frame
                    if torch.sum(dist).item() == 0:
                        break

                if self.args.save:

                    # visualize..
                    est = (comp[0].permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)
                    true = (frames[0, :, f].permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)  # h,w,3
                    mask = (dists[0, 0, f].detach().cpu().numpy() > 0).astype(np.uint8)  # h,w,1
                    ov_true = overlay_davis(true, mask, colors=[[0, 0, 0], [100, 100, 0]], cscale=2, alpha=0.4)

                    canvas = np.concatenate([ov_true, est], axis=0)
                    save_path = os.path.join('Results')
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    canvas = Image.fromarray(canvas)
                    canvas.save(os.path.join(save_path, 'res_{}_{}.jpg'.format(epoch, f)))

       # print("epoch{}".format(epoch) + " test finished")


if __name__ == "__main__":
    train()

This is the training script I wrote myself, largely borrowing from the demo code. I tested it briefly: the images produced after training are not great yet, but the sense of layering and the distribution of color blocks is starting to come through.
As the authors note in the paper, hardware like a V100 is not something ordinary households have; at this point even a 2080 feels inadequate. The only reason I can run this at all is that I happen to have access to the equipment.
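
To summarize the onion-peel loop that both train() and test() run, here is a condensed sketch (it assumes an OPN model plus encoded reference features as set up above; onion_peel is just my own helper name):

import torch


def onion_peel(model, fkey, fval, fhol, frame, valid, dist, max_steps=999):
    """Repeatedly peel one ring (thickness 8) until the distance map is all zero.

    Condensed from the loop in test(); model is assumed to be the OPN defined above.
    """
    comp = frame
    for _ in range(max_steps):
        with torch.no_grad():
            comp, dist, _ = model(fkey, fval, fhol, comp, valid, dist)
        if torch.sum(dist).item() == 0:
            break
    return comp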

(result images)

Summary

The images above come from only half an hour of training; the paper used a learning-rate scheduler, batch_size=5, and trained for four days. Even so, you can see from the images that the model can roughly estimate what the contour of the missing region should look like. I think this is worth pushing further; a full reproduction should be achievable.
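
If I keep going, re-enabling the commented-out cosine learning-rate schedule would be the first step. A minimal sketch of what it does, using a dummy parameter and the T_max/eta_min values from the commented line in train.py:

import torch
import torch.optim as optim

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.Adam(params, lr=1e-3)
# anneal the learning rate from 1e-3 down to 1e-5 over 5 epochs
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
for epoch in range(5):
    optimizer.step()      # one (dummy) update per epoch
    scheduler.step()
    print(epoch, scheduler.get_last_lr())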
