"""Classify the 21-class UCM image dataset with an SVM on raw pixels.

Every image found by ``ImageFolder`` is read with OpenCV, resized to
``IMAGE_SIZE``, flattened into one feature vector and fed to ``svm.SVC``.
A per-class classification report on a held-out split is printed at the end.
"""
import cv2
import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): imported but unused in this script
from sklearn import svm
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder

# Dataset root; ImageFolder treats each sub-folder as one class.
DATA_DIR = 'E:\\pycharm\\data\\Images\\'
IMAGE_SIZE = (224, 224)  # dsize passed to cv2.resize
TEST_SIZE = 0.1
RANDOM_STATE = 200


def load_features(samples, size=IMAGE_SIZE):
    """Read, resize and flatten every image listed in *samples*.

    Args:
        samples: iterable of ``(path, class_index)`` pairs, e.g.
            ``ImageFolder(...).samples``.
        size: target ``dsize`` for ``cv2.resize``; all images end up the same
            size so their flattened vectors can be stacked into one matrix.

    Returns:
        ``(features, labels)``: an array with one flattened image per row and
        an array of the matching class indices.

    Raises:
        ValueError: if OpenCV cannot read one of the files. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            later as an opaque error inside ``cv2.resize``.
    """
    features = []
    labels = []
    for path, label in samples:
        image = cv2.imread(path)
        if image is None:
            raise ValueError(f"cv2.imread could not read image: {path!r}")
        features.append(cv2.resize(image, size).flatten())
        labels.append(label)
    return np.array(features), np.array(labels)


def main():
    """Load the dataset, train the SVM and print the evaluation report."""
    dataset = ImageFolder(DATA_DIR)
    # ImageFolder is used only to enumerate (path, class_index) pairs; the
    # pixels themselves are read with cv2 inside load_features.
    data_arr, data_label = load_features(dataset.samples)
    print(f"Loaded {len(data_label)} images in {len(dataset.classes)} classes; "
          f"feature matrix shape {data_arr.shape}")

    # NOTE(review): consider stratify=data_label so every class is equally
    # represented in the small test split (left unchanged to keep results
    # comparable with earlier runs).
    X_train, X_test, y_train, y_test = train_test_split(
        data_arr, data_label, test_size=TEST_SIZE, random_state=RANDOM_STATE)

    # max_iter caps the solver's iterations (SVC's default of -1 means no
    # limit). On high-dimensional raw-pixel features the solver may stop
    # before converging. NOTE(review): check for a ConvergenceWarning.
    predictor = svm.SVC(gamma='scale', C=1.0, max_iter=1000)
    predictor.fit(X_train, y_train)  # train
    predictions_labels = predictor.predict(X_test)
    print(classification_report(y_test, predictions_labels, digits=4))


if __name__ == "__main__":
    main()
# UCM数据集21类分类 (UCM dataset: 21-class classification)
# 于 2024-06-12 12:49:04 首次发布 (first published 2024-06-12 12:49:04)