import os

import cv2
import numpy as np
try:
    from sklearn import cross_validation
except ImportError:  # sklearn.cross_validation was removed in scikit-learn 0.20
    from sklearn import model_selection as cross_validation
def _collect_image_paths(dir_path, max_files_per_class, extensions):
    """Return image paths found directly inside each sub-directory of *dir_path*.

    Sub-directories and file names are visited in sorted order so the
    selection is reproducible across file systems. At most
    *max_files_per_class* matching files are taken from each sub-directory;
    directories nested below a class directory are not searched.
    """
    paths = []
    for class_name in sorted(os.listdir(dir_path)):
        class_path = os.path.join(dir_path, class_name)
        if not os.path.isdir(class_path):
            continue
        image_names = [
            name for name in sorted(os.listdir(class_path))
            # splitext: a bare '.png' file name has no extension and is skipped.
            if os.path.splitext(name)[1] in extensions
            and os.path.isfile(os.path.join(class_path, name))
        ]
        paths.extend(os.path.join(class_path, name)
                     for name in image_names[:max_files_per_class])
    return paths


def _preprocess_image(path, target_height):
    """Read *path* as grayscale and normalise it to *target_height* columns.

    The image is scaled so its row count becomes *target_height* (aspect
    ratio preserved), rotated 90 degrees clockwise, and intensity-inverted.

    Raises:
        ValueError: if OpenCV cannot decode the file.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:  # imread signals failure by returning None, not raising
        raise ValueError('could not read image: {}'.format(path))
    height, width = img.shape
    scale = float(target_height) / height
    # cv2.resize takes dsize as (width, height); keep at least one column.
    new_width = max(1, int(width * scale))
    img = cv2.resize(img, (new_width, target_height),
                     interpolation=cv2.INTER_CUBIC)
    img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    return 255 - img  # invert intensities


def _stack_images(images):
    """Stack images into one array.

    When the images do not all share one shape, return a 1-D object array
    holding each image instead: NumPy >= 1.24 raises on ``np.array`` of a
    ragged list, while older versions silently produced this object array.
    """
    if len({img.shape for img in images}) <= 1:
        return np.array(images)
    stacked = np.empty(len(images), dtype=object)
    for i, img in enumerate(images):
        stacked[i] = img
    return stacked


def load_image_data(dir_path='E:\\', max_files_per_class=200, test_size=0.2,
                    extensions=('.png', '.tif'), random_state=None,
                    target_height=20):
    """Load per-class grayscale images, normalise them and split train/test.

    Expected layout: ``dir_path/<class_dir>/<image file>``. Every immediate
    sub-directory of *dir_path* is scanned (non-recursively) and the image
    files directly inside it are loaded. Each sample's label is obtained by
    calling ``get_label(path)`` (not defined in this block).

    Each image is read as grayscale, scaled so its row count becomes
    *target_height* (aspect ratio preserved), rotated 90 degrees clockwise
    and inverted, so every sample has *target_height* columns and a
    variable number of rows.

    Args:
        dir_path: root directory containing one sub-directory per class.
        max_files_per_class: maximum number of image files taken from each
            class directory (in sorted file-name order).
        test_size: fraction of samples held out for the test split.
        extensions: case-sensitive file extensions treated as images.
        random_state: seed forwarded to ``train_test_split``; ``None``
            yields a different split on every call.
        target_height: row count images are scaled to before rotation.

    Returns:
        ``(X, Y, testX, testY)``: training images, training labels, test
        images and test labels.

    Raises:
        ValueError: if an image file cannot be decoded by OpenCV.
        OSError: if *dir_path* (or a class directory) cannot be listed.
    """
    try:
        from sklearn.model_selection import train_test_split
    except ImportError:  # scikit-learn < 0.18 only ships the old module
        from sklearn.cross_validation import train_test_split

    paths = _collect_image_paths(dir_path, max_files_per_class, extensions)
    print(len(paths))

    images, labels = [], []
    for path in paths:
        images.append(_preprocess_image(path, target_height))
        labels.append(get_label(path))
    print(len(images), len(labels))

    X = _stack_images(images)
    Y = np.array(labels)

    X, testX, Y, testY = train_test_split(
        X, Y, test_size=test_size, random_state=random_state)
    print('train set:', len(X))
    print('test set:', len(testX))
    return X, Y, testX, testY