from os import listdir
from os.path import join
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image, ImageFilter
import numpy as np
import torch
import cv2
import os
import random
import torch.nn.functional as F
from torchvision.transforms import functional as FF
def is_image_file(filename):
    """Return True if *filename* ends in .png, .jpg or .jpeg (case-insensitive).

    Args:
        filename: File name or path to check.

    Returns:
        bool: whether the extension is one of the supported image types.
    """
    # str.endswith accepts a tuple; lower() so e.g. camera-style "IMG.JPG" matches too.
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))
def load_img(lr_path, hr_path):
    """Load a low-/high-resolution image pair as grayscale PIL images.

    Args:
        lr_path: Path to the low-resolution image.
        hr_path: Path to the high-resolution image.

    Returns:
        ``(lr, hr)``: both images converted to PIL mode ``'L'`` (8-bit grayscale).
    """
    # ``convert`` returns a new, fully loaded image, so the source file can be
    # closed right away; the ``with`` blocks make that deterministic instead of
    # relying on Pillow's lazy loading / garbage collection.
    with Image.open(lr_path) as lr_src:
        lr = lr_src.convert('L')
    with Image.open(hr_path) as hr_src:
        hr = hr_src.convert('L')
    return lr, hr
def random_rot(images):
    """Rotate every image by the same random multiple of 90 degrees.

    One of {90 deg clockwise, 180 deg, 90 deg counter-clockwise, no rotation}
    is chosen uniformly at random and applied to every array in ``images``.
    The list is modified in place and also returned. Rotated outputs are
    C-contiguous copies, so they can be passed directly to ``torch.from_numpy``.

    Args:
        images: List of arrays indexed as ``[row, col, ...]``.

    Returns:
        The same list object, with its entries replaced by rotated arrays.
    """
    # random.randint is inclusive on both ends: 0..3 gives four equally likely
    # outcomes. np.rot90 turns counter-clockwise for positive k, so k=-1 is a
    # clockwise quarter turn; choice 3 means "leave unchanged".
    quarter_turns = {0: -1, 1: 2, 2: 1, 3: 0}[random.randint(0, 3)]
    if quarter_turns == 0:
        return images
    for i, image in enumerate(images):
        # rot90 returns a (possibly negative-stride) view; copy to contiguous.
        images[i] = np.ascontiguousarray(np.rot90(image, quarter_turns))
    return images
def random_flip(images):
    """Randomly mirror all images identically, horizontally and/or vertically.

    Two independent coin flips (probability 0.5 each) decide whether a
    left-right mirror and then a top-bottom mirror are applied. Whatever is
    chosen is applied to every array in ``images``, so paired images stay
    aligned. The list is modified in place and also returned.
    """
    # cv2 flipCode 1 mirrors around the vertical axis (left-right),
    # flipCode 0 mirrors around the horizontal axis (top-bottom).
    for flip_code in (1, 0):
        if random.random() < 0.5:
            images[:] = [cv2.flip(img, flip_code) for img in images]
    return images
def random_crop1(images, split='train', scale=2):
    """Crop a spatially aligned (LR, HR) patch pair for super-resolution.

    Args:
        images: ``[lr, hr]`` arrays indexed as ``[row, col, ...]``. The LR
            image is assumed to be ``scale`` times smaller than the HR image
            along each spatial axis (TODO confirm against the data).
        split: ``'test'`` takes a fixed 256x256 HR patch at the top-left
            corner; any other value takes a random 1200x1200 HR patch.
        scale: Integer down-scaling factor between HR and LR.

    Returns:
        ``[lr_crop, hr_crop]`` where the LR crop covers the same image region
        as the HR crop (``new // scale`` pixels per side).

    Raises:
        ValueError: for a non-test split when the HR image is smaller than
            the 1200x1200 patch.
    """
    lr, hr = images[0], images[1]
    h, w = hr.shape[:2]
    if split == 'test':
        new_h = new_w = 256
        y = x = 0
    else:
        new_h = new_w = 1200
        if h < new_h or w < new_w:
            raise ValueError(
                f"HR image {h}x{w} is smaller than the {new_h}x{new_w} crop")
        # Offsets are multiples of ``scale`` so that ``y // scale`` lands
        # exactly on the LR grid; an offset that is not a multiple would shift
        # the LR crop relative to the HR crop. np.random.randint's upper bound
        # is exclusive, hence the +1 so the crop may touch the border.
        y = scale * np.random.randint(0, (h - new_h) // scale + 1)
        x = scale * np.random.randint(0, (w - new_w) // scale + 1)
    ly, lx = y // scale, x // scale
    lr_crop = lr[ly:ly + new_h // scale, lx:lx + new_w // scale]
    hr_crop = hr[y:y + new_h, x:x + new_w]
    return [lr_crop, hr_crop]
def random_crop2(images, sizeTo=32):
    """Cut the same random ``sizeTo`` x ``sizeTo`` window out of every image.

    The offsets are drawn from the size of ``images[0]``, and every image is
    sliced at identical pixel coordinates, so all images are expected to share
    that spatial size. Along an axis shorter than ``sizeTo`` the offset is 0
    and the crop is simply truncated. The list is modified in place and also
    returned.

    Args:
        images: List of arrays indexed as ``[row, col, ...]``.
        sizeTo: Side length of the square crop, in pixels.

    Returns:
        The same list object, with its entries replaced by the crops (views).
    """
    h, w = images[0].shape[:2]
    # random.randint is inclusive on both ends, so the largest valid offset is
    # (size - sizeTo); subtracting a further 1 would make the last column/row
    # of the image unreachable.
    w_offset = random.randint(0, max(0, w - sizeTo))
    h_offset = random.randint(0, max(0, h - sizeTo))
    for i, image in enumerate(images):
        images[i] = image[h_offset:h_offset + sizeTo, w_offset:w_offset + sizeTo]
    return images
# Presumably the intended training patch size: it matches the literal 32 that
# DatasetFromFolder.__getitem__ passes to random_crop2, but this name is not
# otherwise referenced in this file.
crop_size = 32
class DatasetFromFolder(Dataset):
    """Paired low-/high-resolution grayscale patches read from two folders.

    Every entry of ``hr_path`` must exist under the same file name in
    ``lr_path``. ``__getitem__`` loads both images as grayscale, scales pixel
    values to ``[0, 1]``, cuts a random 32x32 window at the same pixel offsets
    from both, applies one shared random quarter-turn rotation and random
    flips, and returns ``(input, target)`` as float32 tensors with a leading
    channel dimension of size 1.

    NOTE(review): because both images are cropped at identical offsets, the
    LR image is presumably stored at the HR resolution (pre-upsampled);
    confirm against the data.
    """

    def __init__(self, hr_path, lr_path, train):
        """
        Args:
            hr_path: Directory with the high-resolution (target) images.
            lr_path: Directory with the low-resolution (input) images.
            train: Stored as ``self.train``.
        """
        super(DatasetFromFolder, self).__init__()
        self.hr_path = hr_path
        self.lr_path = lr_path
        # NOTE(review): ``train`` is not consulted anywhere in this class; the
        # same random crop/rotation/flip is applied regardless of its value.
        self.train = train
        # os.listdir order is arbitrary and filesystem-dependent; sorting makes
        # the index -> file mapping reproducible across runs and machines.
        self.hr_filenames = sorted(os.listdir(hr_path))

    def __getitem__(self, index):
        """Return a randomly cropped and augmented ``(input, target)`` pair."""
        name = self.hr_filenames[index]
        lr_img, hr_img = load_img(os.path.join(self.lr_path, name),
                                  os.path.join(self.hr_path, name))

        # 8-bit grayscale -> float in [0, 1].
        lr_arr = np.array(lr_img, dtype=np.float64) / 255
        hr_arr = np.array(hr_img, dtype=np.float64) / 255

        # Each helper applies the same random parameters to every list entry,
        # which keeps input and target spatially aligned.
        patches = random_crop2([lr_arr, hr_arr], 32)
        patches = random_rot(patches)
        patches = random_flip(patches)
        lr_patch, hr_patch = patches

        input_tensor = torch.from_numpy(lr_patch).float().unsqueeze(0)
        target_tensor = torch.from_numpy(hr_patch).float().unsqueeze(0)
        return input_tensor, target_tensor

    def __len__(self):
        """Number of image pairs (one per file found in ``hr_path``)."""
        return len(self.hr_filenames)

    def white_balance_3(self, img):
        """Gray-world white balance.

        Scales each of the three channels so that its mean equals the mean of
        the three channel means, clips values above 255, and returns a
        ``uint8`` image with the same shape as ``img``.

        Args:
            img: Image data as read by ``cv2.imread`` (three channels in the
                last axis).

        Returns:
            The white-balanced image as ``uint8``.

        NOTE(review): a channel whose mean is 0 makes its gain divide by zero
        (inf/NaN). The method is not called by ``__getitem__``, which works on
        grayscale images.
        """
        b = np.double(img[:, :, 0])
        g = np.double(img[:, :, 1])
        r = np.double(img[:, :, 2])
        b_ave, g_ave, r_ave = np.mean(b), np.mean(g), np.mean(r)
        k = (b_ave + g_ave + r_ave) / 3
        kb, kg, kr = k / b_ave, k / g_ave, k / r_ave
        # Vectorised clip instead of a per-pixel Python loop: values above 255
        # become 255, everything else (including NaN) is kept as is.
        ba = np.minimum(b * kb, 255)
        ga = np.minimum(g * kg, 255)
        ra = np.minimum(r * kr, 255)
        dst_img = np.uint8(np.zeros_like(img))
        dst_img[:, :, 0] = ba
        dst_img[:, :, 1] = ga
        dst_img[:, :, 2] = ra
        return dst_img