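# Check whether the classes are imbalanced, which also helps in choosing a
# suitable data preprocessing approach.
# (原文: 查看是否存在样本不均衡问题，同时方便选择合适的数据预处理方式。)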
from glob import glob
import pandas as pd
import numpy as np
import os
import cv2
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
# Root directory of the training set. The parent directory of each image is
# used as that image's class label below.
TRAIN_DATASET_PATH = '../train_data'

# Every file that has an extension and sits exactly one level below the root.
image_fns = glob(os.path.join(TRAIN_DATASET_PATH, '*', '*.*'))

# Label = name of the directory containing the image. Going through os.path
# (instead of splitting on '/') keeps this correct where glob returns paths
# with a different separator, e.g. on Windows.
label_names = [os.path.basename(os.path.dirname(s)) for s in image_fns]

# Sorted so the label order is reproducible from run to run.
unique_labels = sorted(set(label_names))

print(len(unique_labels))  # number of distinct classes
print(len(image_fns))      # number of training images
# Count the entries of every class directory to see how evenly the samples
# are spread across classes.
dir_lst = os.listdir(TRAIN_DATASET_PATH)
number_lst = []
for i in dir_lst:
    path = os.path.join(TRAIN_DATASET_PATH, i)
    if not os.path.isdir(path):
        continue  # plain files lying next to the class directories are skipped
    num = len(glob(os.path.join(path, '*')))
    number_lst.append(num)
    if num == 0:
        print(i)  # report class directories that contain nothing

# Histogram of samples-per-class. `density=False` (raw counts) replaces the
# former `normed=0`; the `normed` keyword no longer exists in Matplotlib.
plt.hist(number_lst, bins=40, density=False, facecolor="blue",
         edgecolor="black", alpha=0.7)

# Largest and smallest class size. np.max/np.min raise ValueError on an empty
# sequence, so only report them when at least one class directory was found.
if number_lst:
    print(np.max(number_lst))
    print(np.min(number_lst))
else:
    print('no class directories found in', TRAIN_DATASET_PATH)
# Collect the size of every training image to see which resolutions dominate
# the data set.
dir_lst = os.listdir(TRAIN_DATASET_PATH)
# NOTE(review): this reset discards any counts gathered earlier and the list is
# not filled again in this section; it looks like a copy-paste leftover. Kept
# so the module-level state stays exactly as before.
number_lst = []
size_lst = []
for i in tqdm(dir_lst):
    path = os.path.join(TRAIN_DATASET_PATH, i)
    if not os.path.isdir(path):
        continue  # only class directories contain images
    for j in glob(os.path.join(path, '*')):
        # The context manager closes the file handle right away instead of
        # leaving one open per image until garbage collection.
        with Image.open(j) as img:
            size_lst.append(img.size)

# Frequency of each distinct size tuple. Series.value_counts() replaces the
# deprecated top-level pd.value_counts(); dtype=object keeps every tuple as a
# single element and avoids dtype inference on an empty list.
temp = pd.Series(size_lst, dtype=object).value_counts()
print(temp[temp > 200])  # show only sizes shared by more than 200 images
# Display one randomly chosen training image as a visual sanity check.
# np.random.randint(0) raises ValueError, so skip the preview when no image
# paths were collected.
if image_fns:
    rd_index = np.random.randint(len(image_fns))
    plt.imshow(plt.imread(image_fns[rd_index]))
else:
    print('no training images found; nothing to display')
# Switch the dataset root to test split A.
# NOTE(review): the TRAIN_DATASET_PATH name is reused for the test data here;
# from this point on it no longer refers to the training directory.
TRAIN_DATASET_PATH = '../test_data_A'

# Files with an extension directly inside the 'gallery' and 'query' folders.
gallery_pattern = os.path.join(TRAIN_DATASET_PATH, 'gallery', '*.*')
query_pattern = os.path.join(TRAIN_DATASET_PATH, 'query', '*.*')
image_gal = glob(gallery_pattern)
image_que = glob(query_pattern)

# Report how many files each split contains: gallery first, then query.
for split_files in (image_gal, image_que):
    print(len(split_files))
# Collect image sizes for the gallery and query files separately.
# Each image is opened in a context manager so its file handle is closed
# immediately rather than staying open until garbage collection.
size_lst_gal = []
for i in tqdm(image_gal):
    with Image.open(i) as img:
        size_lst_gal.append(img.size)

size_lst_que = []
for i in tqdm(image_que):
    with Image.open(i) as img:
        size_lst_que.append(img.size)

# Frequency of each distinct size tuple per split. Series.value_counts()
# replaces the deprecated top-level pd.value_counts(); dtype=object keeps every
# tuple as a single element and avoids dtype inference on an empty list.
temp_gal = pd.Series(size_lst_gal, dtype=object).value_counts()
temp_que = pd.Series(size_lst_que, dtype=object).value_counts()

# Only the common sizes are printed; the cut-off is 200 occurrences for the
# gallery and 100 for the query set.
print(temp_gal[temp_gal > 200])
print(temp_que[temp_que > 100])