import os
import glob
import random
from PIL import Image
def dataset(data, name):
    """Letterbox images onto square canvases and split them into test/train.

    Every file in ``data`` is opened, converted to RGB, scaled so that its
    longer side equals ``desired_size`` (aspect ratio preserved), centred on
    a ``desired_size`` x ``desired_size`` RGB canvas and saved under its
    original file name.  The first ``int(len(data) * test_split_ratio)``
    files are written to the ``test`` folder for ``name``; all remaining
    files are written to the ``train`` folder.

    Reads the module-level globals ``test_split_ratio`` and ``desired_size``,
    which must be defined before this function is called.

    Args:
        data: sequence of image file paths, already in the order that
            determines the split (callers shuffle beforehand).
        name: class sub-folder name appended to the test/train directories.
    """
    # Number of images that belong to the test split.
    boundary = int(len(data) * test_split_ratio)
    print(len(data))
    print(boundary)

    test_dir = f'D:\\机器学习\\dogs-vs-cats\\test\\{name}'
    train_dir = f'D:\\机器学习\\dogs-vs-cats\\train\\{name}'
    # Create the output folders up front so save() cannot fail on a
    # missing directory.
    for folder in (test_dir, train_dir):
        os.makedirs(folder, exist_ok=True)

    canvas_size = (desired_size, desired_size)
    for i, file in enumerate(data):
        # convert() returns an independent, fully loaded image, so the
        # source file handle can be released immediately.
        with Image.open(file) as src:
            img = src.convert('RGB')

        ratio = float(desired_size) / max(img.size)
        # Clamp to 1 px so extreme aspect ratios never yield a zero-sized
        # dimension, which resize() rejects.
        new_size = tuple(max(1, int(x * ratio)) for x in img.size)
        # LANCZOS is the filter formerly exposed as ANTIALIAS (that alias
        # was removed in Pillow 10).
        im = img.resize(new_size, Image.LANCZOS)

        new_im = Image.new('RGB', canvas_size)
        new_im.paste(im, ((desired_size - new_size[0]) // 2,
                          (desired_size - new_size[1]) // 2))

        # Indices strictly below the boundary form the test set; using
        # `<=` would leak one extra image into it.
        target_dir = test_dir if i < boundary else train_dir
        new_im.save(os.path.join(target_dir, os.path.basename(file)))
    print("ok")
if __name__ == '__main__':
    # Module-level settings; dataset() reads these two names as globals.
    test_split_ratio = 0.05
    desired_size = 256  # side length of each processed (square) image

    raw_path = "D:\\机器学习\\dogs-vs-cats\\data"  # folder with the raw images

    # Collect the raw files of each class by file-name prefix.
    cats, dogs = (
        glob.glob(os.path.join(raw_path, pattern))
        for pattern in ('cat.*.jpg', 'dog.*.jpg')
    )

    # Author's note on both counts: expected to be 25,000 -- TODO confirm.
    print(f'狗有{len(dogs)}')
    print(f'猫有{len(cats)}')

    # Shuffle in place so that the test/train split is random.
    for files in (cats, dogs):
        random.shuffle(files)

    for files, label in ((cats, "cats"), (dogs, "dogs")):
        dataset(files, label)
# Compute the mean and standard deviation; with too much data, numpy runs out of capacity.
import os
import glob
import random
import shutil
import numpy as np
from PIL import Image
# Compute the per-channel mean and standard deviation over a set of images.


def channel_mean_std(images):
    """Return per-channel mean and std of pixel values scaled to [0, 1].

    The statistics are accumulated image by image with exact integer sums,
    so memory use stays constant no matter how many images are supplied.
    The result equals ``np.mean`` / ``np.std`` (population std, ddof=0) of
    all pixels divided by 255.

    Args:
        images: iterable of arrays whose last axis is the channel axis and
            whose values are integers in 0..255 (e.g. uint8 H x W x C).
            Images may differ in height and width but must share the same
            channel count.

    Returns:
        Tuple ``(mean, std)`` of float64 arrays, one entry per channel.

    Raises:
        ValueError: if ``images`` yields no pixels.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for image in images:
        # Widen to int64 so the running sums (and squares) cannot overflow.
        pixels = np.asarray(image).astype(np.int64)
        pixels = pixels.reshape(-1, pixels.shape[-1])
        if channel_sum is None:
            channel_sum = np.zeros(pixels.shape[1], dtype=np.int64)
            channel_sq_sum = np.zeros(pixels.shape[1], dtype=np.int64)
        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += (pixels * pixels).sum(axis=0)
        pixel_count += pixels.shape[0]

    if pixel_count == 0:
        raise ValueError('no images to compute statistics from')

    mean_255 = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative rounding residue.
    var_255 = np.maximum(channel_sq_sum / pixel_count - mean_255 ** 2, 0.0)
    return mean_255 / 255., np.sqrt(var_255) / 255.


def _iter_rgb_arrays(files):
    """Yield each image file as a uint8 RGB array, closing the file after use."""
    for file in files:
        with Image.open(file) as src:
            yield np.array(src.convert('RGB')).astype(np.uint8)


if __name__ == '__main__':
    # NOTE(review): the variable names say "train" but both directories
    # point at the "test" split -- confirm which split is intended.  The
    # streaming computation no longer needs every image in memory, so the
    # larger split can be processed as well.
    cats_train_files = glob.glob(os.path.join(f'D:\\机器学习\\dogs-vs-cats\\test\\cats', '*'))
    dogs_train_files = glob.glob(os.path.join(f'D:\\机器学习\\dogs-vs-cats\\test\\dogs', '*'))
    print(f'cats{len(cats_train_files)}')
    print(f'dogs{len(dogs_train_files)}')
    train_files = dogs_train_files + cats_train_files
    print(f'train Totally{len(train_files)}')

    # No stacked [BS, H, W, C] array is built any more, so its shape is not
    # printed; the image count is reported by the line above.
    mean, std = channel_mean_std(_iter_rgb_arrays(train_files))
    print(f'mean={mean}')
    print(f'std={std}')