前言
该算法是从github上找的onion peel network算法,但是由于开发者只提供了demo部分,所以我试着把train的部分自己实现了,目前来看多少有点能补全的意思。目前来看还不是很成熟,但我还是发出来给大家看看。当然我把风格严格控制成我以前发的代码风格,方便学习和以后的优化
一、config.py
import argparse


def _str2bool(value):
    """Parse a command-line string into a bool.

    argparse's ``type=bool`` is a well-known trap: any non-empty string is
    truthy, so ``--use_cuda False`` would still yield True.  This converter
    accepts the usual spellings of true/false and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


# Shared command-line configuration for the whole project.
parser = argparse.ArgumentParser(description="Onion Peel Network")
parser.add_argument('--project_name', type=str, default="video completion by Onion Peel Network",
                    help='工程名')
parser.add_argument("--use_cuda", type=_str2bool, default=True,
                    help="是否想使用cuda")
parser.add_argument("--seed", type=int, default=123,
                    help="随机种子")
parser.add_argument("--resume", type=_str2bool, default=True,
                    help="是否使用预训练的权重加载模型")
parser.add_argument("--pretrained_weight", type=str, default='OPN.pth',
                    help="预训练模型加载路径")
parser.add_argument("--lr", type=float, default=0.0001,
                    help="学习率")
parser.add_argument("--weight_decay", type=float, default=1e-4,
                    help="权重衰减系数")
parser.add_argument("--momentum", type=float, default=0.5,
                    help="动量系数")
parser.add_argument("--epoch", type=int, default=10,
                    help="训练epoch次数")
parser.add_argument("--train_batch_size", type=int, default=1,
                    help="训练batch_size")
parser.add_argument("--test_batch_size", type=int, default=1,
                    help="测试batch_size")
parser.add_argument("--save", type=_str2bool, default=True,
                    help="保存图片")
这是我工程的配置文件,跟以前的一样
二、datalist.py
import os
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

# Clip length (frames per sample) and the spatial size every frame is resized to.
T, H, W = 5, 240, 424
class Dataset(Dataset):
    """Sequential dataset: a window of T consecutive frames from Image_inputs/W.

    Returns ``(frames, valids, dists)`` as float tensors shaped (C, T, H, W).
    NOTE(review): the class name shadows ``torch.utils.data.Dataset`` imported
    above; it is kept because ``Dataset2`` inherits from this class.
    """

    def __init__(self, type='train'):
        # 'train' / 'test' switch; currently not used by __getitem__.
        self.type = type

    def __len__(self):
        # The directory holds jpg/png pairs, so half the file count is the
        # number of frames available.
        return len(os.listdir('Image_inputs/W')) // 2

    def __getitem__(self, index):
        # Keep the whole T-frame window inside the directory
        # (fix: removed leftover debug print of the index).
        if index >= len(self) - 5:
            index = index - 5
        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        for i in range(5):
            # rgb frame, scaled to [0, 1] and resized
            img_file = os.path.join('Image_inputs', 'W', '{:04d}.jpg'.format(index + i))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            # binary hole mask, slightly dilated so it fully covers the defect
            mask_file = os.path.join('Image_inputs', 'W', '{:04d}.png'.format(index + i))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # distance from each hole pixel to the nearest valid pixel — this
            # drives the onion-peel "thickness" schedule in the model
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2, maskSize=5)
        # to channel-first tensors: (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        # Cut the holes out and fill them with the ImageNet channel means
        # (fix: 0.4585 was a typo — the standard ImageNet mean, also used by
        # the model's Encoder/Decoder buffers, is 0.485).
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid (non-hole) area, used later when computing the loss
        valids = 1 - holes
        return frames, valids, dists
class Generator(object):
    """Yields one random 5-frame sample (with batch dim) from Image_inputs.

    NOTE(review): ``generator`` currently yields exactly once because the
    original ``while True`` is commented out; restore it for an infinite
    stream.
    """

    def __init__(self, batch_size=1):
        self.batch_size = batch_size
        self.images = os.listdir("Image_inputs")

    def generator(self):
        # while True:
        # pick one sequence directory at random (renamed from `dir`, which
        # shadowed the builtin)
        seq_dir = self.images[random.choice(range(len(self.images)))]
        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        label = np.empty((T, H, W, 3), dtype=np.float32)
        for i in range(5):
            # rgb frame, scaled to [0, 1]; the clean frame is also the label
            img_file = os.path.join('Image_inputs', seq_dir, 'gt_{:1d}.jpg'.format(i))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            label[i] = raw_frame
            # binary hole mask, slightly dilated
            mask_file = os.path.join('Image_inputs', seq_dir, 'mask_{:1d}.png'.format(i))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # distance transform drives the peel schedule
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2, maskSize=5)
        # to channel-first tensors: (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        label = torch.from_numpy(np.transpose(label, (3, 0, 1, 2)).copy()).float()
        # Cut the holes out and fill them with the ImageNet channel means
        # (fix: 0.4585 was a typo for the standard 0.485).
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid (non-hole) area, used later when computing the loss
        valids = 1 - holes
        # add a leading batch dimension
        frames = frames.unsqueeze(0)
        dists = dists.unsqueeze(0)
        valids = valids.unsqueeze(0)
        label = label.unsqueeze(0)
        yield frames, valids, dists, label
class Dataset2(Dataset):
    """Random-sampling training set: each item is 5 random frames from one
    random sequence, paired with 5 random masks.

    ``index`` is intentionally ignored — every access draws a fresh random
    sample.  The paths and counts that were hard-coded are now defaulted
    parameters, so the original call ``Dataset2()`` behaves identically.
    """

    def __init__(self, image_root='image', mask_root='mask', num_masks=52, length=1):
        # image_root: directory of sequence sub-directories with gt_XXXX.jpg
        # mask_root:  directory of mask_XXXX.png files
        # num_masks:  number of mask files to sample from
        # length:     nominal dataset length reported to the DataLoader
        self.image_root = image_root
        self.mask_root = mask_root
        self.num_masks = num_masks
        self.length = length
        self.images = os.listdir(image_root)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        seq_dir = self.images[random.choice(range(len(self.images)))]
        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        label = np.empty((T, H, W, 3), dtype=np.float32)
        seq_path = os.path.join(self.image_root, seq_dir)
        num_frames = len(os.listdir(seq_path))
        for i in range(5):
            # a random clean frame from the sequence; it is also the label
            img_file = os.path.join(seq_path, 'gt_{:04d}.jpg'.format(random.choice(range(num_frames))))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            label[i] = raw_frame
            # a random mask, binarised and slightly dilated
            mask_file = os.path.join(self.mask_root, 'mask_{:04d}.png'.format(random.choice(range(0, self.num_masks))))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # distance transform drives the peel schedule
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2, maskSize=5)
        # to channel-first tensors: (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        label = torch.from_numpy(np.transpose(label, (3, 0, 1, 2)).copy()).float()
        # Cut the holes out and fill them with the ImageNet channel means
        # (fix: 0.4585 was a typo for the standard 0.485).
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid (non-hole) area, used later when computing the loss
        valids = 1 - holes
        return frames, valids, dists, label
这里的data部分,因为需求不一样了,我略微的做了一点改变。用的是generator部分,前面几个data是失败部分。
三.common.py
import torch
import torch.nn as nn
def get_features(img, model, layers=None):
    """Run ``img`` through ``model`` one child module at a time and collect
    selected intermediate activations.

    ``layers`` maps a module's registered name (e.g. '0', '5') to the key
    the activation is stored under; the default mapping targets the usual
    VGG style layers plus the conv4_2 content layer.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content layer
            '28': 'conv5_1',
        }
    captured = {}
    out = img
    for module_name, module in model._modules.items():
        out = module(out)
        if module_name in layers:
            captured[layers[module_name]] = out
    return captured
def gram_matrix(tensor):
    """Return the Gram matrix (d x d) of a (1, d, h, w) feature map.

    The batch dimension is assumed to be 1 and is dropped.
    """
    _, depth, height, width = tensor.size()
    flat = tensor.view(depth, height * width)
    return flat @ flat.t()
'''
TV loss是常用的一种正则项(注意是正则项,配合其他loss一起使用,约束噪声)
图片中相邻像素值的差异可以通过降低TV loss来一定程度上解决
图像上的一点点噪声可能就会对复原的结果产生非常大的影响,因为很多复原算法都会放大噪声。
这时候我们就需要在最优化问题的模型中添加一些正则项来
保持图像的光滑性
'''
class TVLoss(nn.Module):
    """Total-variation regulariser.

    Penalises squared differences between vertically and horizontally
    adjacent pixels, normalised by the number of difference terms — smooths
    isolated noise in the reconstruction when added to other losses.
    """

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch = x.size()[0]
        # pairwise differences along height / width: the image shifted by one
        # pixel subtracted from itself
        dh = x[:, :, 1:, :] - x[:, :, :-1, :]
        dw = x[:, :, :, 1:] - x[:, :, :, :-1]
        # number of difference terms in each direction
        n_h = self._tensor_size(x[:, :, 1:, :])
        n_w = self._tensor_size(x[:, :, :, 1:])
        tv = dh.pow(2).sum() / n_h + dw.pow(2).sum() / n_w
        return self.TVLoss_weight * 2 * tv / batch

    def _tensor_size(self, t):
        # elements per batch item (channels * height * width)
        return t.size()[1] * t.size()[2] * t.size()[3]
class L1_Loss(nn.Module):
    """Masked L1 loss.

    ``pv`` is a per-pixel weight map (e.g. the valid-area mask).  A 5-D
    ``pv`` (B, C, T, H, W) is accumulated frame by frame; a 4-D one gives a
    single sum normalised by the first three sizes of ``y``.
    """

    def __init__(self):
        super(L1_Loss, self).__init__()

    def forward(self, x, y, pv):
        if pv.ndim > 4:
            # (B, C, T, H, W) -> (C, T, H, W), then sum one frame at a time
            pv = pv.squeeze(dim=0)
            loss = 0
            for i in range(pv.size(1)):
                loss += torch.sum(torch.abs(x - y) * pv[:, i, :, :])
            return loss
        # Fix: ``y.shape`` is a torch.Size and must be indexed, not called —
        # the original ``y.shape(0)`` raised TypeError at runtime.
        return torch.sum(torch.abs(x - y) * pv) / (y.shape[0] * y.shape[1] * y.shape[2])
class L1_Lossv2(nn.Module):
    """Flattened variant of the masked L1 loss.

    For a 5-D weight map the per-frame slices ``pv[:, :, i]`` are summed;
    otherwise a single flattened weighted sum is returned.
    """

    def __init__(self):
        super(L1_Lossv2, self).__init__()

    def forward(self, x, y, pv):
        if pv.ndim > 4:
            loss = 0
            # Fix: the slice ``pv[:, :, i]`` indexes dim 2, so the loop must
            # run over ``pv.size(2)``; the original iterated ``pv.size()[1]``
            # and would IndexError (or skip frames) whenever the two sizes
            # differ.
            for i in range(pv.size(2)):
                temp = pv[:, :, i]
                loss = loss + torch.sum(torch.abs(x.flatten() - y.flatten()) * temp.flatten())
            return loss
        return torch.sum(torch.abs(x.flatten() - y.flatten()) * pv.flatten())
def L1(x, y, mask):
    """Masked L1: sum of |x - y| inside ``mask``, normalised by the first
    three sizes of ``y``.

    Fix: ``y.shape`` must be indexed with ``[]`` — the original
    ``y.shape(0)`` called the torch.Size object and raised TypeError.
    """
    res = torch.abs(x - y) * mask
    return torch.sum(res) / (y.shape[0] * y.shape[1] * y.shape[2])
def ll1(x, y):
    """Signed sum of the element-wise difference x - y (note: not an
    absolute L1 — negative and positive errors cancel)."""
    return (x - y).sum()
这里主要是自己写loss,因为论文里他loss挺多的,相关博文介绍又少,我花了好多时间去测试,但自己写loss真的很重要。毕竟现在顶会论文花样越来越多了,已经不局限于网络结构设计都看不懂了,真不想搞了,cnn打天下不好吗
四.model.py
from __future__ import division
# general libs
import math
import sys
sys.path.insert(0, '.')
from .common import *
sys.path.insert(0, '../utils/')
from utils.helpers import *
class Encoder(nn.Module):
    """Encodes a frame + valid mask + hole mask into key/value feature maps
    at 1/4 spatial resolution for the asymmetric attention read.

    NOTE(review): attribute names are presumably load-bearing — they must
    match the keys of the pre-trained OPN.pth state_dict; do not rename.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # input has 5 channels: RGB frame + valid mask + hole mask
        self.conv12 = GatedConv2d(5, 64, kernel_size=5, stride=2, padding=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/2 scale
        self.conv2 = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                                 activation=nn.LeakyReLU(negative_slope=0.2))  # 1/2
        self.conv23 = GatedConv2d(64, 128, kernel_size=3, stride=2, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv3a = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        # dilated convolutions enlarge the receptive field without further
        # downsampling
        self.conv3b = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv3c = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=4, dilation=4,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv3d = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=8, dilation=8,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        # key/value projection heads (no activation)
        self.key3 = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)  # 1/4
        self.val3 = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)  # 1/4
        # ImageNet normalisation constants (buffers so .to(device) moves them)
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f, in_v, in_h):
        # in_f: frame, in_v: valid mask, in_h: hole mask.
        # Normalise the frame with ImageNet statistics.
        f = (in_f - self.mean) / self.std
        x = torch.cat([f, in_v, in_h], dim=1)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3a(x)
        x = self.conv3b(x)
        x = self.conv3c(x)
        x = self.conv3d(x)
        # project to separate key and value maps
        k = self.key3(x)
        v = self.val3(x)
        return k, v
# Asymmetric attention module
class MaskedRead(nn.Module):
    """Attention read restricted to masked positions.

    For every query position inside ``qmask``, attends over memory positions
    inside ``mmask`` and adds the attended value back into ``qval`` at those
    positions (in place; ``qval`` is also returned).
    """

    def __init__(self):
        super(MaskedRead, self).__init__()

    def forward(self, qkey, qval, qmask, mkey, mval, mmask):
        '''
        read for *mask area* of query from *mask area* of memory
        '''
        # key:   b, dk, t, h, w
        # value: b, dv, t, h, w
        # mask:  b, 1, t, h, w
        B, Dk, _, H, W = mkey.size()
        _, Dv, _, _, _ = mval.size()
        for b in range(B):
            # exceptions: no query or no memory pixels -> nothing to read
            if qmask[b, 0].sum() == 0 or mmask[b, 0].sum() == 0:
                # print('skipping read', qmask[b,0].sum(), mmask[b,0].sum())
                continue
            # boolean-mask gather flattens the selected positions
            qk_b = qkey[b, :, qmask[b, 0]]  # dk, Nq
            mv_b = mval[b, :, mmask[b, 0]]  # dv, Nm
            mk_b = mkey[b, :, mmask[b, 0]]  # dk, Nm
            # similarity of every memory position to every query position
            p = torch.mm(torch.transpose(mk_b, 0, 1), qk_b)  # Nm, Nq
            # scale by sqrt(Dk) — standard scaled dot-product attention
            # normalisation (the original comment claimed this "prevents
            # overfitting", which is not what this does)
            p = p / math.sqrt(Dk)
            p = torch.softmax(p, dim=0)
            read = torch.mm(mv_b, p)  # dv, Nq
            # add the attended values into qval at the query-mask positions
            qval[b, :, qmask[b, 0]] = qval[b, :, qmask[b, 0]] + read  # dv, Nq
        return qval
class Decoder(nn.Module):
    """Decodes the attended 1/4-scale feature map back to an RGB image.

    Mirrors the Encoder's dilated stack (in reverse order), upsamples twice
    back to full resolution, then undoes the ImageNet normalisation.
    NOTE(review): attribute names presumably match the pre-trained OPN.pth
    state_dict keys; do not rename.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        # dilated stack, largest dilation first (reverse of the Encoder)
        self.conv3d = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=8, dilation=8,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv3c = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=4, dilation=4,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv3b = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv3a = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/4
        self.conv32 = GatedConv2d(128, 64, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 1/2
        self.conv2 = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                                 activation=nn.LeakyReLU(negative_slope=0.2))  # 1/2
        self.conv21 = GatedConv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)  # full
        # ImageNet statistics for de-normalising the prediction
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        x = self.conv3d(x)
        x = self.conv3c(x)
        x = self.conv3b(x)
        x = self.conv3a(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 1/4 -> 1/2
        x = self.conv32(x)
        x = self.conv2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 1/2 -> full
        x = self.conv21(x)
        # undo the ImageNet normalisation so the output is in image space
        p = (x * self.std) + self.mean
        return p
class OPN(nn.Module):
    """Onion Peel Network.

    ``memorize`` encodes reference frames into key/value features once;
    ``read`` then inpaints a query frame iteratively, "peeling" at most
    ``thickness`` distance-transform units of the hole per call, from the
    hole boundary inward.

    ``mode`` and ``CPU_memory`` are accepted for backward compatibility but
    are currently unused.
    """

    def __init__(self, mode='Train', CPU_memory=False, thickness=8):
        super(OPN, self).__init__()
        self.Encoder = Encoder()
        self.MaskedRead = MaskedRead()
        self.Decoder = Decoder()
        # distance-transform units peeled per read() step
        self.thickness = thickness
        # ImageNet mean in 4-D (single frame) and 5-D (frame stack) layouts
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1, 1))

    def memorize(self, frames, valids, dists):
        '''
        encode every frame of *valid* area into key:value
        Done once as initialization
        '''
        # pad H and W to multiples of 4 (the encoder downsamples twice)
        (frames, valids, dists), pad = pad_divide_by([frames, valids, dists], 4, (frames.size()[3], frames.size()[4]))
        # positive distance-transform value == hole pixel
        holes = (dists > 0).float()
        # blank out hole pixels with the ImageNet mean colour
        frames = (1 - holes) * frames + holes * self.mean3d
        batch_size, _, num_frames, height, width = frames.size()
        # encode each frame independently, then stack along the time axis
        key_ = []
        val_ = []
        for t in range(num_frames):
            key, val = self.Encoder(frames[:, :, t], valids[:, :, t], holes[:, :, t])
            key_.append(key)
            val_.append(val)
        keys = torch.stack(key_, dim=2)
        vals = torch.stack(val_, dim=2)
        # hole mask downsampled to the 1/4-scale feature resolution
        hols = (F_upsample3d(holes, size=(int(height / 4), int(width / 4)), mode='bilinear', align_corners=False) > 0)
        return keys, vals, hols

    def read(self, mkey, mval, mhol, frame, valid, dist):
        '''
        ## assume single frame query
        1) encode current status of frames -> query
        2) read from memmories (computed calling 'memorize')
        3) decode readed feature
        4) compute loss on peel area
        '''
        thickness = self.thickness
        # pad to multiples of 4, remembering the pad for the final crop
        (frame, valid, dist), pad = pad_divide_by([frame, valid, dist], 4, (frame.size()[2], frame.size()[3]))
        batch_size, _, height, width = frame.size()
        # hole = all missing pixels; peel = outermost ring of the hole,
        # at most `thickness` distance units deep
        hole = (dist > 0).float()
        peel = hole * (dist <= thickness).float()
        # shrink the distance map: the peeled ring counts as filled next time
        next_dist = torch.clamp(dist - thickness, 0, 9999)
        # peel mask at 1/4 scale for the attention read
        # (fix: F.upsample is deprecated and removed in recent PyTorch;
        # F.interpolate is the supported equivalent)
        peel3 = (F.interpolate(peel, size=(int(height / 4), int(width / 4)), mode='bilinear', align_corners=False) >= 0.5)
        # blank the hole with the ImageNet mean before encoding the query
        frame = (1 - hole) * frame + hole * self.mean
        qkey, qval = self.Encoder(frame, valid, hole)
        qpel = peel3
        # asymmetric attention: query positions = peel ring,
        # memory positions = non-hole areas of the reference frames
        read = self.MaskedRead(qkey, qval, qpel, mkey, mval, ~mhol)
        pred = self.Decoder(read)
        comp = (1 - peel) * frame + peel * pred  # fill only the peel area
        # crop the padding back off (pad_divide_by splits floor/ceil, so the
        # right/bottom pad is > 0 whenever any pad was added on that axis,
        # making the negative slice bound safe)
        if pad[2] + pad[3] > 0:
            comp = comp[:, :, pad[2]:-pad[3], :]
            next_dist = next_dist[:, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            comp = comp[:, :, :, pad[0]:-pad[1]]
            next_dist = next_dist[:, :, :, pad[0]:-pad[1]]
        # clamp to the valid image range [0, 1]
        comp = torch.clamp(comp, 0, 1)
        return comp, next_dist, peel

    def forward(self, *args, **kwargs):
        # Dispatch: 3 positional args -> memorize(frames, valids, dists);
        # otherwise -> read(mkey, mval, mhol, frame, valid, dist).
        if len(args) == 3:
            return self.memorize(*args)
        else:
            return self.read(*args, **kwargs)
这里是论文中的相关部分,说实话这技术太厉害了,一般人根本写不出来,全是抄的就不说了
五.model_common.py
from __future__ import division
# general libs
import sys
import torch.nn.functional as F
sys.path.insert(0, '../utils/')
from utils.helpers import *
##########################################
############ Generic #################
##########################################
def pad_divide_by(in_list, d, in_size):
    """Zero-pad every tensor in ``in_list`` so the spatial size ``in_size``
    (h, w) becomes a multiple of ``d``.

    Padding is split as evenly as possible (extra pixel on the right/bottom).
    Returns the padded tensors and the pad tuple (left, right, top, bottom),
    which callers use to crop results back afterwards.
    """
    h, w = in_size
    new_h = h if h % d == 0 else h + d - h % d
    new_w = w if w % d == 0 else w + d - w % d
    top = (new_h - h) // 2
    bottom = (new_h - h) - top
    left = (new_w - w) // 2
    right = (new_w - w) - left
    pad_array = (left, right, top, bottom)
    padded = [F.pad(item, pad_array) for item in in_list]
    return padded, pad_array
class ConvGRU(nn.Module):
    """Convolutional GRU cell: the gates are computed with 2-D convolutions
    instead of matrix products, preserving spatial structure."""

    def __init__(self, mdim, kernel_size=3, padding=1):
        super(ConvGRU, self).__init__()
        # each conv emits the reset / update / candidate pre-activations at once
        self.convIH = nn.Conv2d(mdim, 3 * mdim, kernel_size=kernel_size, padding=padding)
        self.convHH = nn.Conv2d(mdim, 3 * mdim, kernel_size=kernel_size, padding=padding)

    def forward(self, input, hidden_tm1):
        # first step: start from an all-zero hidden state
        if hidden_tm1 is None:
            hidden_tm1 = torch.zeros_like(input)
        gi = self.convIH(input)
        gh = self.convHH(hidden_tm1)
        i_r, i_i, i_n = torch.chunk(gi, 3, dim=1)
        h_r, h_i, h_n = torch.chunk(gh, 3, dim=1)
        resetgate = torch.sigmoid(i_r + h_r)  # reset
        inputgate = torch.sigmoid(i_i + h_i)  # update
        # Fix: F.tanh is deprecated (and removed in recent PyTorch versions);
        # torch.tanh is the supported, numerically identical replacement.
        newgate = torch.tanh(i_n + resetgate * h_n)
        # equivalent to: inputgate * hidden_tm1 + (1 - inputgate) * newgate
        hidden_t = newgate + inputgate * (hidden_tm1 - newgate)
        return hidden_t
def F_upsample3d(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    """Apply 2-D interpolation independently to every temporal slice of a
    5-D tensor (B, C, T, H, W), then re-stack along the time axis."""
    resized = [
        F.interpolate(x[:, :, t], size=size, scale_factor=scale_factor,
                      mode=mode, align_corners=align_corners)
        for t in range(x.size()[2])
    ]
    return torch.stack(resized, dim=2)
def F_upsample(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    """Interpolation dispatcher: 5-D tensors go through the per-frame 3-D
    helper, everything else straight to F.interpolate."""
    if x.dim() == 5:  # (B, C, T, H, W)
        return F_upsample3d(x, size=size, scale_factor=scale_factor,
                            mode=mode, align_corners=align_corners)
    return F.interpolate(x, size=size, scale_factor=scale_factor,
                         mode=mode, align_corners=align_corners)
class GatedConv2d(nn.Module):
    """Gated convolution: output = activation(feature(x)) * sigmoid(gate(x)).

    The learned sigmoid gate lets the layer suppress feature responses at
    selected spatial positions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, activation=None):
        super().__init__()
        conv_args = (in_channels, out_channels, kernel_size,
                     stride, padding, dilation, groups, bias)
        # two parallel convolutions with identical geometry
        self.input_conv = nn.Conv2d(*conv_args)
        self.gating_conv = nn.Conv2d(*conv_args)
        init_He(self)  # He weight initialisation (helper from utils.helpers)
        self.activation = activation

    def forward(self, input):
        feature = self.input_conv(input)
        if self.activation:
            feature = self.activation(feature)
        gate = torch.sigmoid(self.gating_conv(input))
        return feature * gate
这里是模型的相关配置文件,也是抄的
六.train.py
import os
import sys
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import models
from common import *
from datalist import Dataset2
from models.OPN import OPN
from utils.helpers import *
sys.path.append('utils/')
sys.path.append('models/')
# Per-layer weights for the VGG style loss: earlier (texture) layers
# contribute more than deeper ones.
style_weights = {
    'conv1_1': 1,
    'conv2_1': 0.8,
    'conv3_1': 0.5,
    'conv4_1': 0.3,
    'conv5_1': 0.1,
}
from config import parser  # shared argparse parser (see config.py)
class train(object):
    """Training driver for the Onion Peel Network.

    NOTE(review): the class is named ``train`` (lower-case) and also has a
    ``train`` method — kept as-is so the ``__main__`` guard keeps working.
    Instantiating the class runs the whole training loop from ``__init__``.
    """

    def __init__(self):
        self.args = parser.parse_args()
        print(f"-----------{self.args.project_name}-----------")
        # device / RNG setup
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(self.args.seed)
        else:
            torch.manual_seed(self.args.seed)
        self.device = torch.device("cuda" if use_cuda else "cpu")
        kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
        '''
        构造DataLoader
        '''
        # ToDo: the dataset still needs to be rebuilt properly
        print("Create Dataloader")
        # NOTE(review): batch_size is hard-coded to 1 here, so the
        # --train_batch_size / --test_batch_size options are ignored.
        self.train_loader = DataLoader(Dataset2(), batch_size=1, shuffle=True, **kwargs)
        self.test_loader = DataLoader(Dataset2(), batch_size=1, shuffle=True, **kwargs)
        '''
        定义模型
        '''
        print("Create Model")
        self.model = OPN().to(self.device)
        # self.model = nn.DataParallel(OPN())
        if use_cuda:
            # self.model = self.model.cuda()
            cudnn.benchmark = True
        '''
        根据需要加载预训练的模型权重参数
        '''
        # Frozen pre-trained VGG16 features, used only for the perceptual
        # (content/style) losses.
        self.vgg = models.vgg16(pretrained=True).to(self.device).features
        for i in self.vgg.parameters():
            i.requires_grad = False
        try:
            if self.args.resume and self.args.pretrained_weight:
                self.model.load_state_dict(torch.load(os.path.join('OPN.pth')), strict=False)
                print("模型加载成功")
        except:
            # NOTE(review): bare except silently swallows *any* failure
            # (missing file, corrupt weights, typos); narrowing to
            # (OSError, RuntimeError) and logging the exception would be safer.
            print("模型加载失败")
        '''
        cuda加速
        '''
        if use_cuda:
            # self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True
        '''
        构造loss目标函数
        选择优化器
        学习率变化选择
        '''
        print("Establish the loss, optimizer and learning_rate function")
        self.loss_tv = TVLoss()
        self.loss_l1=L1_Loss()
        # style loss and content loss are computed inline in train() below
        # self.optimizer = optim.SGD(
        #     params=self.model.parameters(),
        #     lr=self.args.lr,
        #     weight_decay=self.args.weight_decay,
        #     momentum=0.5
        # )
        # NOTE(review): the learning rate is hard-coded to 0.001 — the --lr
        # command-line option is ignored here.
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,  # guards against division by zero
            weight_decay=0
        )
        # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=5, eta_min=1e-5)
        '''
        模型开始训练
        '''
        print("Start training")
        for epoch in tqdm(range(1, self.args.epoch + 1)):
            self.train(epoch)
            # NOTE(review): with the default --epoch 10 this condition never
            # fires, so test() is never called during a default run.
            if epoch % 20==0:
                self.test(epoch)
            torch.cuda.empty_cache()
        print("finish model training")

    def train(self, epoch):
        """One epoch: memorise all 5 frames, then peel-and-inpaint each frame
        against the other four, accumulating pixel L1 + perceptual content +
        style + TV losses before a single backward per batch."""
        self.model.train()
        for data in self.train_loader:
            self.content_loss = 0
            self.style_loss = 0
            midx = list(range(0, 5))
            # frames: corrupted images; valids: usable pixel area;
            # dists: distance map of the area still to be filled
            frames, valids, dists, label = data
            frames, valids, dists, label = frames.to(self.device), valids.to(self.device), dists.to(
                self.device), label.to(self.device)
            # Every frame is encoded once; key/val come out as
            # (1, 128, 5, 60, 106) and hol as (1, 1, 5, 60, 106).
            mkey, mval, mhol = self.model(frames[:, :, midx], valids[:, :, midx], dists[:, :, midx])
            allloss=0
            for f in range(5):
                loss=0
                # use the other 4 frames as reference memory for frame f
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = mkey[:, :, ridx], mval[:, :, ridx], mhol[:, :, ridx]
                # iterative completion: each pass peels one ring of the hole
                for r in range(5):
                    if r == 0:
                        # start from the corrupted frame itself
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    # comp:   image being progressively completed
                    # valids: region with no missing information
                    # dist:   region with missing information
                    '''
                    按dist的指导,逐8个像素的距离,循环修复图片,其中valids表示空洞部分的区域(0,1)
                    comp是在frame的基础之上补充的,相似度极高,只计算这一部分的loss
                    '''
                    comp, dist, peel = self.model(fkey, fval, fhol, comp, valids[:, :, f], dist)
                    # minimise the L1 distance to ground truth on the freshly
                    # peeled ring (pixel space)
                    loss += 100 * L1(comp, label[:, :, f], peel)
                    # loss += L1(comp, label[:, :, f], valids[:,:,f])
                    loss+=0.2*self.loss_l1(comp,label[:,:,f],valids[:,:,midx])
                    # loss+=100*ll1(comp,frames[:,:,f])
                    # content loss: L1 distance in VGG conv4_2 feature space
                    content_features = get_features(frames[:, :, f], self.vgg)
                    target_features = get_features(comp, self.vgg)
                    self.content_loss = torch.mean(
                        torch.abs((target_features['conv4_2'] - content_features['conv4_2'])))
                    loss = loss + 0.05 * self.content_loss
                    # style loss: Gram-matrix distance, accumulated per layer
                    style_features = get_features(comp, self.vgg)
                    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
                    '''加上每一层的gram_matrix矩阵的损失'''
                    for layer in style_weights:
                        target_feature = target_features[layer]
                        target_gram = gram_matrix(target_feature)
                        _, d, h, w = target_feature.shape
                        style_gram = style_grams[layer]
                        layer_style_loss = style_weights[layer] * torch.mean(torch.abs((target_gram - style_gram)))
                        # normalise by feature-map size and accumulate
                        self.style_loss += layer_style_loss / (d * h * w)
                    loss = loss + 120 * self.style_loss
                    # TV regulariser on the completed frame
                    loss += 0.01 * self.loss_tv(comp)
                allloss+=loss
            self.optimizer.zero_grad()
            allloss.backward()
            self.optimizer.step()
            # self.scheduler.step()
            # print("epoch{}".format(epoch) + " loss:{}".format(loss.cpu()))

    def test(self, epoch):
        """Evaluate: peel each frame repeatedly until no hole pixels remain,
        optionally saving a stacked (overlayed input / estimate) image."""
        self.model.eval()
        for frames, valids, dists, _ in self.test_loader:
            midx = list(range(0, 5))
            # frames, valids, dists = data
            frames, valids, dists = frames.to(self.device), valids.to(self.device), dists.to(self.device)
            with torch.no_grad():
                # encode all 5 frames once up front
                mkey, mval, mhol = self.model(frames[:, :, midx], valids[:, :, midx], dists[:, :, midx])
            # each frame uses the other 4 as references
            for f in range(5):
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = mkey[:, :, ridx], mval[:, :, ridx], mhol[:, :, ridx]
                # iterative completion (999 = generous upper bound; the loop
                # breaks as soon as the hole is fully filled)
                for r in range(999):
                    if r == 0:
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    with torch.no_grad():
                        comp, dist,peel = self.model(fkey, fval, fhol, comp, valids[:, :, f], dist)
                        comp, dist = comp.detach(), dist.detach()
                    # once every hole pixel is filled, stop peeling this frame
                    if torch.sum(dist).item() == 0:
                        break
                if self.args.save:
                    # visualize..
                    est = (comp[0].permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)
                    true = (frames[0, :, f].permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)  # h,w,3
                    mask = (dists[0, 0, f].detach().cpu().numpy() > 0).astype(np.uint8)  # h,w,1
                    ov_true = overlay_davis(true, mask, colors=[[0, 0, 0], [100, 100, 0]], cscale=2, alpha=0.4)
                    canvas = np.concatenate([ov_true, est], axis=0)
                    save_path = os.path.join('Results')
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    canvas = Image.fromarray(canvas)
                    canvas.save(os.path.join(save_path, 'res_{}_{}.jpg'.format(epoch, f)))
        # print("epoch{}".format(epoch) + " test finished")
if __name__ == "__main__":
    # Instantiating the class runs the whole training pipeline
    # (see train.__init__ above).
    train()
这里就是我自己写的train,主要还是借鉴了demo的相关内容,我自己测试了一下训练出来的图片虽然效果不佳,但是层次感和色块的分布倒有一点感觉了。
毕竟跟作者在论文中提到的一样,v100这种设备普通人家都没有的,我现在感觉2080都是垃圾。当然我能运行,是因为我有机会能接触到设备呗。
总结
上面的图片我只跑了半个小时,论文中用到了learning-scheduler,batch-size=5,并且跑了4天。从上面的图中可以看出模型能大概估计出损失部分的轮廓应该长什么样子。我觉得需要继续深入,复现应该可以做到