import torch.utils.data as data
from torchvision.transforms import *
from os import listdir
from os.path import join
from PIL import Image
import random
def is_image_file(filename):
    """Return True if *filename* ends with a supported image extension.

    Supported extensions are ``.png``, ``.jpg``, ``.jpeg`` and ``.bmp``.
    Matching is case-insensitive, so ``photo.JPG`` or ``scan.Png`` are
    accepted as well as their lowercase forms.

    Args:
        filename: File name or path to test.

    Returns:
        bool: Whether the name has one of the supported extensions.
    """
    # Lower-case first so upper-case extensions (common from cameras and
    # Windows tools) are not silently dropped; endswith accepts a tuple of
    # suffixes and checks them all in a single call.
    return filename.lower().endswith((".png", ".jpg", ".jpeg", ".bmp"))
def load_img(filepath):
    """Open the image at *filepath* and return an RGB copy of it.

    The source file is opened in a ``with`` block so its handle is released
    deterministically once the conversion is done, instead of relying on
    garbage collection (multi-frame formats otherwise keep the file open).

    Args:
        filepath: Path to the image file.

    Returns:
        PIL.Image.Image: The image converted to mode ``'RGB'``.
    """
    with Image.open(filepath) as img:
        # convert() returns a new image object holding its own pixel data,
        # so it stays usable after the source file is closed.
        return img.convert('RGB')
def calculate_valid_crop_size(crop_size, scale_factor):
    """Round *crop_size* down to the nearest multiple of *scale_factor*.

    This keeps crops evenly divisible by the scale factor.

    Args:
        crop_size: Requested crop size.
        scale_factor: Factor the result must be a multiple of.

    Returns:
        ``crop_size`` minus its remainder modulo ``scale_factor``.
    """
    leftover = crop_size % scale_factor
    return crop_size - leftover
class TrainDatasetFromFolder(data.Dataset):
def __init__(self, image_dirs, is_gray=False, random_scale=True, crop_size=128, rotate=True, fliplr=True,
fliptb=True, scale_factor=4):
super(TrainDatasetFromFolder, self).__init__()
self.image_filenames = []
for image_dir in image_dirs:
self.image_filenames.extend(join(image_dir, x) for x in so