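At the moment I find image preprocessing quite hard; there are a lot of fiddly details involved.
Today I'm again analyzing the code from the two versions of DnCNN.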
https://github.com/SaoYan/DnCNN-PyTorch/blob/master/dataset.py
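import numpy as np

def Im2Patch(img, win, stride=1):
    # img: array of shape (C, H, W); despite their names, endw is the size of axis 1 (rows) and endh of axis 2 (columns)
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    # top-left corners of all patches, one every `stride` pixels along each spatial axis
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    # Y[:, k, n] holds pixel number k (row-major inside the win x win window) of patch n
    Y = np.zeros([endc, win*win, TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            # the pixel at offset (i, j) inside every patch, gathered for all patches at once
            patch = img[:, i:endw-win+i+1:stride, j:endh-win+j+1:stride]
            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    # result has shape (C, win, win, number_of_patches)
    return Y.reshape([endc, win, win, TotalPatNum])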
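As the name suggests, Im2Patch splits a whole image into small patches. img.shape is (channels, height, width); the code calls the last two sizes endw and endh. To be honest, I don't really understand this code... I'd love for someone more experienced to walk me through it.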
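One way to make Im2Patch easier to follow is to compare it with a naive loop. The sketch below uses a hypothetical helper, `im2patch_naive` (not part of the repo), that cuts out one patch per top-left corner. Im2Patch turns this around: it loops over the win x win offsets inside a patch, and for each offset (i, j) a single strided slice picks that pixel out of every patch at once. The check confirms both views agree on a small random-free test array.

```python
import numpy as np

def im2patch_naive(img, win, stride=1):
    """Hypothetical helper: extract one patch at a time.

    img has shape (C, H, W); returns an array of shape (C, win, win, N) with the
    patches ordered row by row over their top-left corners.
    """
    c, h, w = img.shape
    patches = []
    for top in range(0, h - win + 1, stride):
        for left in range(0, w - win + 1, stride):
            patches.append(img[:, top:top + win, left:left + win])
    return np.stack(patches, axis=-1)

img = np.arange(2 * 5 * 6, dtype=np.float32).reshape(2, 5, 6)
win, stride = 3, 2
c, h, w = img.shape
naive = im2patch_naive(img, win, stride)

# Im2Patch's trick: fix an offset (i, j) inside the window; one strided slice
# then collects that pixel from every patch, in the same patch order
matches = []
for i in range(win):
    for j in range(win):
        column = img[:, i:h - win + i + 1:stride, j:w - win + j + 1:stride].reshape(c, -1)
        matches.append(np.array_equal(column, naive[:, i, j, :]))
print(naive.shape, all(matches))  # (2, 3, 3, 4) True
```

The loop over offsets has only win*win iterations regardless of how many patches there are, which is presumably why the repo prefers it over looping across thousands of patch positions.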
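import cv2
import numpy as np

# scales, patch_size, stride, aug_times and data_aug are module-level names
# defined elsewhere in the same script
def gen_patches(file_name):
    # get multiscale patches from a single image
    img = cv2.imread(file_name, 0)  # gray scale
    h, w = img.shape
    patches = []
    for s in scales:
        h_scaled, w_scaled = int(h*s), int(w*s)
        # cv2.resize takes dsize as (width, height); the original passed (h, w), which only matters for non-square images
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        # extract patches: (i, j) is the top-left corner, moved `stride` pixels at a time
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]
                for k in range(0, aug_times):
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    patches.append(x_aug)
    return patches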
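This patch extraction I can follow: i (and likewise j) is where each patch starts, a new start every `stride` pixels, and each patch covers i to i+patch_size.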
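Since the start positions are 0, stride, 2*stride, ... as long as a full patch still fits, an axis of length L yields (L - patch_size) // stride + 1 starts, and the number of patches per scaled image is the product over both axes. A small sketch that checks this formula against the same loop structure as gen_patches; the image size and the patch_size/stride values are made-up examples, not settings taken from the repo.

```python
def num_patches(h, w, patch_size, stride):
    # number of valid top-left corners along each axis, multiplied together
    return ((h - patch_size) // stride + 1) * ((w - patch_size) // stride + 1)

def count_by_loop(h, w, patch_size, stride):
    # same loop structure as gen_patches, counting instead of slicing
    return sum(1
               for i in range(0, h - patch_size + 1, stride)
               for j in range(0, w - patch_size + 1, stride))

h, w = 180, 120               # example image size (assumed)
patch_size, stride = 40, 10   # example settings (assumed)
for s in [1, 0.9, 0.8, 0.7]:
    hs, ws = int(h * s), int(w * s)  # int() truncates, as in gen_patches
    print(s, (hs, ws), num_patches(hs, ws, patch_size, stride))
```

With aug_times copies per patch, the total for one image is the sum of these per-scale counts times aug_times.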
The next part prepares the data. Here is the code, version one:
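import glob
import os

import cv2
import h5py
import numpy as np

# Im2Patch, normalize and data_augmentation are defined elsewhere in the repo
def prepare_data(data_path, patch_size, stride, aug_times=1):
    # train
    print('process training data')
    scales = [1, 0.9, 0.8, 0.7]
    files = glob.glob(os.path.join(data_path, 'train', '*.png'))
    files.sort()
    h5f = h5py.File('train.h5', 'w')
    train_num = 0
    for i in range(len(files)):
        img = cv2.imread(files[i])
        h, w, c = img.shape
        for k in range(len(scales)):
            # cv2.resize takes dsize as (width, height); the original passed (h, w), which only matters for non-square images
            Img = cv2.resize(img, (int(w*scales[k]), int(h*scales[k])), interpolation=cv2.INTER_CUBIC)
            # keep only channel 0 and add a leading channel axis: (H, W) -> (1, H, W)
            Img = np.expand_dims(Img[:,:,0].copy(), 0)
            Img = np.float32(normalize(Img))
            patches = Im2Patch(Img, win=patch_size, stride=stride)  # shape (1, win, win, N)
            print("file: %s scale %.1f # samples: %d" % (files[i], scales[k], patches.shape[3]*aug_times))
            for n in range(patches.shape[3]):
                data = patches[:,:,:,n].copy()
                # every patch becomes its own dataset in the file, named by a running counter
                h5f.create_dataset(str(train_num), data=data)
                train_num += 1
                # aug_times-1 extra copies, each with a random non-identity transform (mode 1..7)
                for m in range(aug_times-1):
                    data_aug = data_augmentation(data, np.random.randint(1,8))
                    h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
                    train_num += 1
    h5f.close()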
After this there is another part that handles the validation set, which I won't paste here.
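scales = [1, 0.9, 0.8, 0.7] are the scaling factors: each training image is also resized to 90%, 80% and 70% of its size, to get more training data.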
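Scaling is one way to get more data; the aug_times loops are another. Both versions call an augmentation helper (data_aug in gen_patches, data_augmentation in prepare_data) with a mode from 0 to 7, but its body isn't shown here. Eight modes match the eight symmetries of a square: four rotations by 90 degrees, each with or without a flip. Below is a sketch under that assumption; the name `augment` and the mode-to-transform mapping are mine and may not match the repo's numbering. It works on a 2-D (H, W) patch; for the (1, H, W) patches in prepare_data you would apply it to data[0].

```python
import numpy as np

def augment(patch, mode):
    """Hypothetical 8-mode augmentation of a 2-D patch of shape (H, W).

    mode // 2 = number of 90-degree rotations, mode % 2 = whether to flip
    upside down afterwards. Mode 0 is the identity.
    """
    out = np.rot90(patch, k=mode // 2)
    if mode % 2 == 1:
        out = np.flipud(out)
    return out.copy()  # rot90/flipud return views; copy gives a standalone array

patch = np.arange(9).reshape(3, 3)  # no symmetry, so all 8 results differ
variants = [augment(patch, m) for m in range(8)]
distinct = {v.tobytes() for v in variants}
print(len(distinct))  # 8
```

Because rotations and flips only rearrange pixels, the noise statistics of a patch stay the same, which is why these transforms are safe for denoising data.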
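files = glob.glob(os.path.join(data_path, 'train', '*.png')) is quite an important line: it collects the paths of all .png files in the train folder under data_path. See the tutorial https://blog.csdn.net/zmdzbzbhss123/article/details/52279008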
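Here is what that glob line does, run on a throwaway directory with made-up file names. glob.glob returns the matching paths in no guaranteed order (it depends on the file system), which is why prepare_data calls files.sort() to make the order reproducible.

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as data_path:
    train_dir = os.path.join(data_path, 'train')
    os.makedirs(train_dir)
    for name in ['test_002.png', 'test_001.png', 'notes.txt']:
        open(os.path.join(train_dir, name), 'w').close()  # empty placeholder files

    files = glob.glob(os.path.join(data_path, 'train', '*.png'))  # notes.txt is skipped
    files.sort()  # sort for a stable, reproducible order
    names = [os.path.basename(f) for f in files]

print(names)  # ['test_001.png', 'test_002.png']
```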
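cv2.resize does the scaling, and np.expand_dims adds a dimension: here it turns the 2-D (H, W) single-channel image into (1, H, W). See the tutorial https://www.jianshu.com/p/da10840660cb for details. I don't understand why this is needed, but this operation shows up in practically every dataset preprocessing script.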
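A quick shape check shows what expand_dims does here. Within this code, one concrete reason is that Im2Patch reads the channel count from img.shape[0], so it expects (C, H, W). More generally (my understanding, not something the repo states), PyTorch's nn.Conv2d works on batches shaped (N, C, H, W), so each grayscale sample needs an explicit channel axis C = 1, and the DataLoader later stacks samples along a new batch axis N. The zero array is a stand-in for a real image.

```python
import numpy as np

img = np.zeros((50, 40, 3), dtype=np.uint8)  # stand-in for a cv2 image: (H, W, 3)
gray = img[:, :, 0].copy()                   # keep one channel: (H, W)
chw = np.expand_dims(gray, 0)                # add a channel axis: (1, H, W)
batch = np.stack([chw, chw])                 # what batching does: (N, 1, H, W)
print(gray.shape, chw.shape, batch.shape)    # (50, 40) (1, 50, 40) (2, 1, 50, 40)
```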
For an explanation of normalize, see https://blog.csdn.net/weicao1990/article/details/54584788?utm_source=blogxgwz6
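In this repo, normalize appears to simply divide by 255 (that is my reading of dataset.py; check it to confirm), mapping 8-bit pixel values from 0..255 into [0, 1]. A sketch of that assumed behaviour, followed by the same np.float32 conversion prepare_data applies:

```python
import numpy as np

def normalize(data):
    # assumed behaviour: scale 8-bit pixel values from [0, 255] to [0, 1]
    return data / 255.

img = np.array([[0, 128, 255]], dtype=np.uint8)
out = np.float32(normalize(img))  # division gives float64; convert to float32 as prepare_data does
print(out.dtype, out.min(), out.max())  # float32 0.0 1.0
```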
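HDF5 (.h5) files have always been a blank spot for me, but processed images usually end up stored in one, so I'm catching up on the basics now: http://docs.h5py.org/en/stable/quick.html#quick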
https://blog.csdn.net/y_f_raquelle/article/details/84134851
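The HDF5 pattern used by prepare_data and the Dataset class comes down to three steps: open a file, store each patch as its own named dataset, and read a patch back by its key. A minimal round trip with h5py, using a temporary file and random data in place of train.h5:

```python
import os
import tempfile

import h5py
import numpy as np

patches = np.random.rand(3, 1, 40, 40).astype(np.float32)  # three fake (1, 40, 40) patches

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'train.h5')
    # write: one dataset per patch, named by its index as a string
    with h5py.File(path, 'w') as h5f:
        for n, data in enumerate(patches):
            h5f.create_dataset(str(n), data=data)
    # read back: keys are the dataset names; np.array(...) loads the data into memory
    with h5py.File(path, 'r') as h5f:
        keys = sorted(h5f.keys())
        first = np.array(h5f['0'])

print(keys, first.shape)  # ['0', '1', '2'] (1, 40, 40)
```

Note that the keys are strings, so with many patches '10' sorts before '2'; the Dataset class shuffles the keys anyway, so the order doesn't matter there.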
import random

import h5py
import numpy as np
import torch
import torch.utils.data as udata

class Dataset(udata.Dataset):
    def __init__(self, train=True):
        super(Dataset, self).__init__()
        self.train = train
        # only the list of keys is kept in memory; the patches stay on disk
        if self.train:
            h5f = h5py.File('train.h5', 'r')
        else:
            h5f = h5py.File('val.h5', 'r')
        self.keys = list(h5f.keys())
        random.shuffle(self.keys)
        h5f.close()

    def __len__(self):
        # number of samples
        return len(self.keys)

    def __getitem__(self, index):
        # reopen the file and load just the one patch stored under this key
        if self.train:
            h5f = h5py.File('train.h5', 'r')
        else:
            h5f = h5py.File('val.h5', 'r')
        key = self.keys[index]
        data = np.array(h5f[key])
        h5f.close()
        return torch.Tensor(data)
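Finally, a class is defined. Since I came to programming from another field, I don't really understand how classes work yet, so I should probably study classes first before I can use this kind of code flexibly.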
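To get a first feel for classes, here is a stripped-down, pure-Python sketch of the same idea without PyTorch or HDF5. A class bundles data (here the list of keys, set up once in __init__) with the methods that use it. PyTorch's DataLoader mainly relies on two of those methods: __len__ (how many samples there are) and __getitem__ (return sample number index). `ToyDataset` and its in-memory dict are hypothetical stand-ins for the real class and the h5 file.

```python
import random

class ToyDataset:
    """Hypothetical stand-in for the Dataset class above; a dict plays the h5 file."""

    def __init__(self, store, shuffle=True):
        # runs once when the object is created: remember the data and its keys
        self.store = store
        self.keys = list(store.keys())
        if shuffle:
            random.shuffle(self.keys)

    def __len__(self):
        # called by len(dataset)
        return len(self.keys)

    def __getitem__(self, index):
        # called by dataset[index]: map a position to a key, then load that sample
        key = self.keys[index]
        return self.store[key]

store = {'0': [0.1, 0.2], '1': [0.3, 0.4], '2': [0.5, 0.6]}
ds = ToyDataset(store, shuffle=False)
print(len(ds), ds[1])  # 3 [0.3, 0.4]
```

`self` is simply the object itself: whatever __init__ stores on it (self.keys, self.store) is still there when __len__ and __getitem__ are called later.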