阿里巴巴优酷视频增强和超分辨率挑战赛记录

20 篇文章 5 订阅
13 篇文章 0 订阅

之前做过超分辨率,刚好有这个比赛,拿来记录一下。截至目前初赛,score=40.22,排名46。
更新,已经复赛B轮了,目前排名24。

数据

官方给的txt,包含数据下载链接与格式说明,用于获取数据:

y4m 格式介绍:https://wiki.multimedia.cx/index.php/YUV4MPEG2
y4m 与 yuv(yuv420 8bit planar) 互转命令:
    y4mtoyuv: ffmpeg -i xx.y4m -vsync 0 xx.yuv  -y
    yuvtoy4m: ffmpeg -s 1920x1080 -i xx.yuv -vsync 0 xx.y4m -y
y4m 与 png 互转命令:
   y4mtobmp: ffmpeg -i xx.y4m -vsync 0 xx%3d.bmp -y
   bmptoy4m: ffmpeg -i xx%3d.bmp  -pix_fmt yuv420p  -vsync 0 xx.y4m -y
y4m 每25帧抽样命令:
   ffmpeg -i xxx.y4m -vf select='not(mod(n\,25))' -vsync 0  -y xxx_sub25.y4m

## 初赛训练数据下载链接
round1_train_input:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00000_00049_l.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00050_00099_l.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00100_00149_l.zip

round1_train_label:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00000_00049_h_GT.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00050_00099_h_GT.zip
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00100_00149_h_GT.zip

## 初赛验证数据下载链接
round1_val_input:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/input/youku_00150_00199_l.zip

round1_val_label:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/train/label/youku_00150_00199_h_GT.zip

## 初赛测试数据下载链接
round1_test_input:
http://tianchi-media.oss-cn-beijing.aliyuncs.com/231711_youku/round1/test/input/youku_00200_00249_l.zip

自己写了几个python脚本批量处理数据:
1.每个y4m抽取100张图片:

import os
import subprocess

# Extract every frame of each .y4m under ./y4m_ into ./image_bmp/ as BMPs.
#
# Uses subprocess.run with an argument list (no shell) instead of the
# original os.system string concatenation, so filenames containing spaces
# or shell metacharacters cannot break or inject into the command.
os.makedirs("./image_bmp", exist_ok=True)
for fname in os.listdir("./y4m_"):
    if fname.endswith('.y4m'):
        print(fname)
        # Exit code is intentionally not checked (best-effort batch job,
        # same as the original); subprocess.run blocks until ffmpeg exits,
        # so the original time.sleep(5) was dead time and is dropped.
        subprocess.run(
            ["ffmpeg", "-i", "./y4m_/" + fname, "-q:v", "2", "-vsync", "0",
             "./image_bmp/" + fname[:-4] + "%3d.bmp", "-y"],
        )

2.每个y4m每25帧抽取一张图片(对应官方的Sub25抽帧提交方式)

import os
import subprocess

# Sample one frame out of every 25 from each .y4m under ./y4m and dump
# them as BMPs into ./image_bmp/ (matches the contest's Sub25 protocol).
#
# subprocess.run with an argument list replaces the original os.system
# shell string, so no shell quoting of the filter is needed; the backslash
# before ',' is still required by ffmpeg's own filtergraph parser
# (',' separates filters inside -vf).
os.makedirs("./image_bmp", exist_ok=True)
for fname in os.listdir("./y4m"):
    if fname.endswith('.y4m'):
        subprocess.run(
            ["ffmpeg", "-i", "./y4m/" + fname,
             "-vf", "select=not(mod(n\\,25))",
             "-vsync", "0",
             "./image_bmp/" + fname[:-4] + "%3d.bmp", "-y"],
        )

3.将图片转为所需的.y4m:图片放在image_x4下,下一级目录名即为要生成的.y4m文件名,各目录内存放对应的帧图片

import os
import subprocess

# Re-encode each frame directory under ./image_x4 back into a .y4m clip:
# ./image_x4/<name>/001.bmp ... -> ./result/<name>.y4m, forced to yuv420p
# as the submission format requires.
#
# subprocess.run with an argument list replaces the original os.system
# shell string (robust against odd directory names); it blocks until
# ffmpeg exits, so the original sleep is unnecessary.
os.makedirs("./result", exist_ok=True)
for clip in os.listdir("./image_x4"):
    subprocess.run(
        ["ffmpeg", "-i", "./image_x4/" + clip + "/%3d.bmp",
         "-pix_fmt", "yuv420p", "-vsync", "0",
         "./result/" + clip + ".y4m", "-y"],
    )

4.批量改名

import os

# Batch-rename every file in ./result to the submission naming scheme:
# the original extension (last 4 chars) is stripped and replaced with
# "_h_Sub25_Res.y4m".
result_dir = "./result/"
for entry in os.listdir(result_dir):
    renamed = f"{result_dir}{entry[:-4]}_h_Sub25_Res.y4m"
    os.rename(result_dir + entry, renamed)
    print(renamed)

5.先用插值算法测试:

import cv2
import os

# Baseline submission: 4x Lanczos interpolation of every frame in ./image,
# written to ./image_x4/<clip_id>/<frame>.bmp-style layout consumed by the
# y4m re-encoding step.
#
# The original duplicated the whole upscaling body in both branches of an
# if/else that only differed in creating the output directory;
# os.makedirs(..., exist_ok=True) removes the duplication.
for name in os.listdir("image"):
    out_dir = "./image_x4/" + name[:11]          # first 11 chars = clip id
    os.makedirs(out_dir, exist_ok=True)
    image = cv2.imread("./image/" + name)
    res = cv2.resize(image, (image.shape[1] * 4, image.shape[0] * 4),
                     interpolation=cv2.INTER_LANCZOS4)
    cv2.imwrite(out_dir + "/" + name[-7:], res)  # last 7 chars = frame number + ext
    print(name)

比赛发现的比较晚,先用插值试一下提交模型的步骤,成绩:
在这里插入图片描述
6.测试代码,在ESRGAN算法进行改进,损失只用MSE:

import sys
import os.path
import glob
import cv2
import numpy as np
import torch
import architecture as arch

model_path = sys.argv[1]  # path to the trained ESRGAN (MSE-loss) weights
device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu

# 26 RRDB blocks (vs. 23 in the reference ESRGAN), 4x upscaling.
model = arch.RRDB_Net(3, 3, 64, 26, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
                        mode='CNA', upsample_mode='upconv')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
# Inference only: freeze all parameters.
for k, v in model.named_parameters():
    v.requires_grad = False
model = model.to(device)

print('Model path {:s}. \nTesting...'.format(model_path))


def _upscale_one(name):
    """Run one BMP from ./image_bmp through the network and write the 4x result."""
    img = cv2.imread("./image_bmp/" + name)
    img = img * 1.0 / 255
    # BGR -> RGB, HWC -> CHW, then add the batch dimension.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0).to(device)

    output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
    # CHW RGB -> HWC BGR for cv2.imwrite.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round()
    cv2.imwrite("./image_x4/" + name[:11] + "/" + name[-7:], output)
    print(name)


# The original duplicated the entire inference body in both branches of an
# if/else that only differed in creating the output directory;
# makedirs(exist_ok=True) plus a helper removes the duplication.
for name in os.listdir("image_bmp"):
    os.makedirs("./image_x4/" + name[:11], exist_ok=True)
    _upscale_one(name)

目前的成绩,暂时只用了2000 * 2张图片训练
在这里插入图片描述

数据增强

90、180、270度旋转

# Rotation augmentation: write 90/180/270-degree rotated copies of every
# .png in a folder, next to the originals.
#
# NOTE(review): the original used scipy.misc.imread/imsave, which were
# deprecated in SciPy 1.0 and removed in SciPy 1.2+, so it no longer runs
# on a current install. Rewritten with PIL, which the other augmentation
# scripts in this post already use; Image.transpose(ROTATE_90) rotates
# counter-clockwise, the same direction as scipy.ndimage.rotate.
import os
import glob
from PIL import Image


def get_image_paths(folder):
    """Return the paths of all .png files directly inside *folder*."""
    return glob.glob(os.path.join(folder, '*.png'))


def create_read_img(filename):
    """Save 90/180/270-degree rotated copies of one image next to it."""
    im = Image.open(filename)
    for angle, op in ((90, Image.ROTATE_90),
                      (180, Image.ROTATE_180),
                      (270, Image.ROTATE_270)):
        im.transpose(op).save(f"{filename[:-4]}_{angle}.png")
    print(filename)


img_path = '/media/wxy/000F8E4B0002F751/test/'
imgs = get_image_paths(img_path)
#print (imgs)

for i in imgs:
    create_read_img(i)

镜像翻转
根据原始图像名称进行翻转

import cv2
import os

# Mirror augmentation: for every *original* frame in ./HR_image/ (original
# file names are exactly 23 characters long, which skips files produced by
# a previous augmentation pass), write a left-right and a top-bottom
# flipped copy alongside it.
HR_DIR = "./HR_image/"
for fname in os.listdir(HR_DIR):
    if len(fname) != 23:
        continue  # not an original frame name
    src = cv2.imread(HR_DIR + fname)
    stem = HR_DIR + fname[:-4]
    cv2.imwrite(stem + "_flip_h.png", cv2.flip(src, 1))  # left-right
    cv2.imwrite(stem + "_flip_w.png", cv2.flip(src, 0))  # top-bottom
    print(fname)

同时增强

from PIL import Image
import os
import glob


def get_image_paths(folder):
    """Return the paths of all .png files directly inside *folder*."""
    return glob.glob(os.path.join(folder, '*.png'))


def create_read_img(filename):
    """Save flipped (h/v) and rotated (90/180/270) copies of one image."""
    im = Image.open(filename)
    variants = (
        ('_h', Image.FLIP_LEFT_RIGHT),
        ('_w', Image.FLIP_TOP_BOTTOM),
        ('_90', Image.ROTATE_90),
        ('_180', Image.ROTATE_180),
        ('_270', Image.ROTATE_270),
    )
    for suffix, op in variants:
        im.transpose(op).save(filename[:-4] + suffix + '.png')
    print(filename)


img_path = '/media/wxy/000F8E4B0002F751/test/'
imgs = get_image_paths(img_path)

for i in imgs:
    create_read_img(i)

多线程图像增强

import time
import os
from concurrent.futures import ThreadPoolExecutor
from PIL import Image

# Multi-threaded image augmentation.
#
# The original depended on the long-unmaintained third-party `threadpool`
# package; the standard library's concurrent.futures.ThreadPoolExecutor
# provides the same fan-out without the extra dependency.

name = ["/media/wxy/000F8E4B0002F751/test/" + name_ for name_ in os.listdir("./test")]


def create_read_img(filename):
    """Save flipped (h/v) and rotated (90/180/270) copies of one image."""
    im = Image.open(filename)
    out_h = im.transpose(Image.FLIP_LEFT_RIGHT)
    out_w = im.transpose(Image.FLIP_TOP_BOTTOM)
    out_90 = im.transpose(Image.ROTATE_90)
    out_180 = im.transpose(Image.ROTATE_180)
    out_270 = im.transpose(Image.ROTATE_270)

    out_h.save(filename[:-4] + '_h.png')
    out_w.save(filename[:-4] + '_w.png')
    out_90.save(filename[:-4] + '_90.png')
    out_180.save(filename[:-4] + '_180.png')
    out_270.save(filename[:-4] + '_270.png')
    print(filename)


start_time = time.time()
# 5 workers: same degree of parallelism as the original threadpool.ThreadPool(5).
with ThreadPoolExecutor(max_workers=5) as pool:
    # list() drains the iterator so any worker exception surfaces here,
    # mirroring the original pool.wait().
    list(pool.map(create_read_img, name))
print('%d second' % (time.time() - start_time))

PSNR、SSIM测试代码

import cv2
import numpy as np
import math

def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr (MATLAB-style coefficients).
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the same dtype as the input and never mutates it.
    '''
    in_img_type = img.dtype
    # Bug fix: the original called img.astype(np.float32) WITHOUT assigning
    # the result, so the conversion was a no-op and -- worse -- the
    # `img *= 255.` below mutated the CALLER's float array in place.
    # astype() returns a copy, so this also protects the input.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (coefficients are in BGR channel order)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)

def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio between two images in the [0, 255] range.

    Returns +inf when the images are identical.
    """
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    # 10*log10(MAX^2 / MSE) is algebraically the same as
    # 20*log10(MAX / sqrt(MSE)).
    return 10 * math.log10(255.0 ** 2 / mse)

def calculate_ssim(img1, img2):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    2-D arrays are scored directly; HxWx3 images are scored per channel
    and averaged; HxWx1 is squeezed first.
    Raises ValueError on shape mismatch or unsupported dimensionality.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            # Bug fix: the original ignored the loop index and passed the
            # full 3-channel images three times. (cv2.filter2D filters each
            # channel independently, so the averaged result was the same,
            # but the work was done 3x.) Index each channel explicitly.
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[..., i], img2[..., i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')

def ssim(img1, img2):
    """Single-channel SSIM with an 11x11 Gaussian window (sigma=1.5),
    the constants used by the reference MATLAB implementation."""
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    # [5:-5] keeps only the 'valid' region of the 11-tap filtering.
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def main():
    """Compute Y-channel PSNR/SSIM between a ground-truth frame and a
    super-resolved frame, cropping a 4-pixel border before scoring."""
    gt_img = cv2.imread("Youku_00150_h_GT001.png")
    sr_img = cv2.imread("Youku_00150_l001_120000.png")

    # (The original also built two unused grayscale copies here; removed.)

    # Normalise to [0, 1] floats as bgr2ycbcr's float path expects.
    gt_img = gt_img / 255
    sr_img = sr_img / 255

    # Compare on the luma (Y) channel only — presumably matching the
    # contest's official metric; confirm against the rules.
    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
    gt_img_y = bgr2ycbcr(gt_img, only_y=True)

    # Crop a 4-pixel border (common super-resolution evaluation practice).
    cropped_sr_img_y = sr_img_y[4:-4, 4:-4]
    cropped_gt_img_y = gt_img_y[4:-4, 4:-4]

    # Scale back to [0, 255] because both metrics assume that range.
    psnr_y = calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)
    ssim_y = calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)

    print(psnr_y, ssim_y)


if __name__ == '__main__':
    main()
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值