基于(BOVW-SIFT)的SAR遥感影像分类
文章目录
导入库
import numpy as np
import cv2
import torch
import torchvision
import matplotlib.pyplot as plt
import tqdm.auto as tqdm
from sklearn.cluster import KMeans
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.metrics import plot_confusion_matrix
def show_results(clf, X, y, imgs, show_positive=True):
    """Plot up to eight random samples the classifier got right (or wrong).

    Args:
        clf: Fitted estimator exposing ``predict``.
        X: Feature matrix handed to ``clf.predict``.
        y: Ground-truth labels, compared element-wise with the predictions.
        imgs: Images indexable by sample position; drawn with ``plt.imshow``.
        show_positive: If True, show correctly classified samples;
            otherwise show misclassified ones.

    Returns:
        None. Draws a 2x4 grid of images titled with prediction and label.
        When no sample matches the requested outcome, a message is printed
        and nothing is drawn.
    """
    preds = clf.predict(X)
    idxs = np.where((preds == y) == show_positive)[0]

    if idxs.size == 0:
        outcome = "correctly" if show_positive else "incorrectly"
        print(f"No {outcome} classified samples to display.")
        return

    # Sample without replacement and cap at the number of available samples:
    # rejection sampling would never terminate with fewer than 8 candidates.
    n_show = min(8, idxs.size)
    chosen = np.random.choice(idxs, size=n_show, replace=False)

    plt.figure(figsize=(20, 10))
    for slot, val in enumerate(chosen):
        plt.subplot(2, 4, slot + 1)
        plt.title(f"Pred={preds[val]}, Label={y[val]}", fontsize=20, fontname="Times New Roman")
        plt.axis('off')
        plt.imshow(imgs[val])
    plt.show()
加载数据
# Preprocessing: resize to 150x150, convert to a tensor, then move the
# channel axis last and rescale by 255 so the uint8 cast below is meaningful.
transform_img = torchvision.transforms.Compose([
    torchvision.transforms.Resize((150, 150)),
    torchvision.transforms.ToTensor(),
    # Single permutation equivalent to permute(2, 1, 0) followed by permute(1, 0, 2).
    torchvision.transforms.Lambda(lambda chw: chw.permute(1, 2, 0) * 255),
])
dataset = torchvision.datasets.ImageFolder(
    root='./dataset/OpenSARUrban/train/',
    transform=transform_img,
)
# One shuffled batch spanning the whole dataset, so everything ends up in memory.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=len(dataset),
    shuffle=True,
)
batch_imgs, batch_labels = next(iter(train_dataloader))
img_arr = batch_imgs.numpy().astype(np.uint8)
label_arr = batch_labels.numpy()

# Quick visual sanity check of one sample and its label.
plt.imshow(img_arr[100])
print(label_arr[100])
5
# Convert every RGB image to single-channel grayscale for SIFT.
# The output drops the trailing channel axis of img_arr.
img_arr_gray = np.zeros(img_arr.shape[:-1], dtype=np.uint8)
for img_idx, rgb_img in enumerate(img_arr):
    img_arr_gray[img_idx] = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)

# Show the first grayscale image together with its label.
plt.imshow(img_arr_gray[0], cmap='gray')
print(label_arr[0])
6
执行 Dense SIFT
# Dense SIFT: place a keypoint every `step_size` pixels on a regular grid
# over the 150x150 image (row-major order: y in the outer loop, x inner).
step_size = 10
kp = [
    cv2.KeyPoint(x, y, step_size)
    for y in range(0, 150, step_size)
    for x in range(0, 150, step_size)
]

# One 128-dimensional descriptor per grid keypoint per image.
sift = cv2.SIFT_create()
desc_arr = np.zeros((img_arr.shape[0], len(kp), 128))
for img_idx in tqdm.tqdm(range(img_arr.shape[0])):
    # compute() returns (keypoints, descriptors); only the descriptors are kept.
    desc_arr[img_idx] = sift.compute(img_arr_gray[img_idx], kp)[1]
100%|██████████| 6720/6720 [00:46<00:00, 145.30it/s]
# Notebook display of the descriptor array shape: (n_images, n_keypoints, 128).
desc_arr.shape
(6720, 225, 128)
# Flatten to one row per descriptor so all descriptors can be clustered together.
reshaped_desc = desc_arr.reshape(-1, desc_arr.shape[-1])
reshaped_desc.shape
(1512000, 128)
训练 KMeans
# Size of the visual vocabulary (number of KMeans clusters).
N_CLUSTERS = 500

# Cluster all training descriptors; each cluster centre acts as a visual word.
kmeans = KMeans(n_clusters=N_CLUSTERS, n_init=3)
kmeans.fit(reshaped_desc)

# Fold the flat cluster assignments back to (n_images, n_keypoints).
n_images, n_keypoints = desc_arr.shape[:-1]
kmeans_labels = kmeans.labels_.reshape(n_images, n_keypoints)
训练
准备归一化直方图
# Visual-word count table: one row per training image, one column per cluster.
freq_table = np.zeros((len(img_arr), N_CLUSTERS))
freq_table.shape
(6720, 500)
# Count, for every image, how many of its keypoints fall into each cluster.
for img_idx, word_ids in enumerate(kmeans_labels):
    freq_table[img_idx] = np.bincount(word_ids, minlength=N_CLUSTERS)
准备 TF-IDF
# Log-scaled term frequency; freq_table still holds raw per-image counts here.
tf = np.log(freq_table + 1)

# Document frequency: number of training images in which each visual word
# occurs at least once. Read directly from the count table instead of
# re-scanning kmeans_labels once per cluster.
tf_Doc = np.count_nonzero(freq_table, axis=0).astype(float)

# Inverse document frequency; the +1 keeps the denominator non-zero for
# visual words that never occur.
idf = np.log(img_arr.shape[0] / (1 + tf_Doc))
tfidf = tf * idf

# Turn the raw counts into normalized histograms by dividing by the number
# of keypoints per image. Must happen after tf / tf_Doc are computed.
freq_table = freq_table / kmeans_labels.shape[-1]
Training using LinearSVC with One vs Rest Multiclass Setting
# Linear SVM on the normalized visual-word histograms.
clf = LinearSVC(C=10, max_iter=10000)
clf.fit(freq_table, label_arr)

# Linear SVM on the TF-IDF weighted features.
tfidf_clf = LinearSVC(C=0.001, max_iter=20000)
tfidf_clf.fit(tfidf, label_arr)
训练结果
# Training accuracy = fraction of training samples predicted correctly.
hist_train_acc = np.mean(clf.predict(freq_table) == label_arr)
tfidf_train_acc = np.mean(tfidf_clf.predict(tfidf) == label_arr)
print(f"Normalized Histogram - Training Accuracy: {hist_train_acc}")
print(f"TFIDF - Training Accuracy: {tfidf_train_acc}")
Normalized Histogram - Training Accuracy: 0.5047619047619047
TFIDF - Training Accuracy: 0.5223214285714286
# For each model: confusion matrix on the training set, then example grids of
# correctly and incorrectly classified images.
train_reports = (
    (clf, freq_table, ""),
    (tfidf_clf, tfidf, "Confusion Matrix Training - TFIDF"),
)
for model, feats, title in train_reports:
    plot_confusion_matrix(model, feats, label_arr)
    plt.title(title)
    plt.show()
    for want_correct in (True, False):
        show_results(model, feats, label_arr, img_arr, show_positive=want_correct)
我们可以看到,与归一化直方图相比,TF-IDF 并没有带来明显的性能提升(训练集上两者的准确率相近,约为 0.52 对 0.50)。 这是意料之中的,因为 TF-IDF 的效果对参数调整非常敏感。
TF-IDF 主要通过改善特征矩阵的条件数来帮助算法收敛。 性能通常不会有很大的改进,因为在这种情况下,能够通过这种加权方式抑制的噪声方向通常很少。 另一方面,最终性能更依赖于 LinearSVC 的参数调整。
测试集
加载数据集
# Same preprocessing as for the training split: resize to 150x150, convert to
# a tensor, then move the channel axis last and rescale by 255.
transform_img = torchvision.transforms.Compose([
    torchvision.transforms.Resize((150, 150)),
    torchvision.transforms.ToTensor(),
    # Single permutation equivalent to permute(2, 1, 0) followed by permute(1, 0, 2).
    torchvision.transforms.Lambda(lambda x: x.permute(1, 2, 0) * 255),
])
test_dataset = torchvision.datasets.ImageFolder(root='./dataset/OpenSARUrban/test/', transform=transform_img)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset))

# Fetch the single full-size batch once; calling next(iter(...)) separately
# for images and labels would load and decode the whole test set twice.
test_imgs, test_labels = next(iter(test_dataloader))
test_img_arr = test_imgs.numpy().astype(np.uint8)
test_label_arr = test_labels.numpy()

# Grayscale copies for SIFT; the output drops the trailing channel axis.
test_img_arr_gray = np.zeros(test_img_arr.shape[:-1], dtype=np.uint8)
for i in range(test_img_arr.shape[0]):
    test_img_arr_gray[i] = cv2.cvtColor(test_img_arr[i], cv2.COLOR_RGB2GRAY)
plt.imshow(test_img_arr_gray[0], cmap='gray')
<matplotlib.image.AxesImage at 0x227334a7250>
执行 Dense SIFT
# Dense SIFT on the test images: same regular keypoint grid as for training
# (one keypoint every `step_size` pixels, y in the outer loop, x inner).
step_size = 10
kp = [
    cv2.KeyPoint(x, y, step_size)
    for y in range(0, 150, step_size)
    for x in range(0, 150, step_size)
]

# NOTE: desc_arr is reused here and now holds the test descriptors.
sift = cv2.SIFT_create()
desc_arr = np.zeros((test_img_arr.shape[0], len(kp), 128))
for img_idx in tqdm.tqdm(range(test_img_arr.shape[0])):
    # compute() returns (keypoints, descriptors); only the descriptors are kept.
    desc_arr[img_idx] = sift.compute(test_img_arr_gray[img_idx], kp)[1]
100%|██████████| 1680/1680 [00:09<00:00, 175.63it/s]
分配 KMeans 标签
# Assign every test descriptor to its nearest training cluster, then fold the
# flat assignments back to (n_images, n_keypoints).
flat_test_desc = desc_arr.reshape(-1, 128)
kmeans_labels_test = kmeans.predict(flat_test_desc).reshape(desc_arr.shape[:-1])
kmeans_labels_test.shape
(1680, 225)
准备归一化直方图
# Visual-word count table: one row per test image, one column per cluster.
test_freq_table = np.zeros((len(test_img_arr), N_CLUSTERS))
test_freq_table.shape
(1680, 500)
# Count, for every test image, how many of its keypoints fall into each cluster.
for img_idx, word_ids in enumerate(kmeans_labels_test):
    test_freq_table[img_idx] = np.bincount(word_ids, minlength=N_CLUSTERS)
准备 TF-IDF
# Log-scaled term frequency; test_freq_table still holds raw counts here.
test_tf = np.log(test_freq_table + 1)

# Document frequency on the test split. Kept for inspection only; it is
# deliberately NOT used to weight the features.
test_tf_Doc = np.count_nonzero(test_freq_table, axis=0).astype(float)

# Reuse the IDF weights estimated on the training split. tfidf_clf was fitted
# on features scaled by `idf`, so re-estimating IDF from the test images would
# rescale the features inconsistently and leak test-set statistics.
test_idf = idf
test_tfidf = test_tf * test_idf

# Turn the raw counts into normalized histograms by dividing by the number
# of keypoints per image. Must happen after test_tf is computed.
test_freq_table = test_freq_table / kmeans_labels_test.shape[-1]
测试数据集结果
# Test accuracy = fraction of test samples predicted correctly.
hist_test_acc = np.mean(clf.predict(test_freq_table) == test_label_arr)
tfidf_test_acc = np.mean(tfidf_clf.predict(test_tfidf) == test_label_arr)
print(f"Normalized Histogram - Test Accuracy: {hist_test_acc}")
print(f"TFIDF - Test Accuracy: {tfidf_test_acc}")
Normalized Histogram - Test Accuracy: 0.4232142857142857
TFIDF - Test Accuracy: 0.4172619047619048
# For each model: confusion matrix on the test set, then example grids of
# correctly and incorrectly classified images.
test_reports = (
    # Title left blank on purpose (formerly "Confusion Matrix Test - Normalized Histograms").
    (clf, test_freq_table, '', {'cmap': 'coolwarm'}),
    (tfidf_clf, test_tfidf, "Confusion Matrix Test - TFIDF", {}),
)
for model, feats, title, cm_kwargs in test_reports:
    plot_confusion_matrix(model, feats, test_label_arr, **cm_kwargs)
    plt.title(title)
    plt.show()
    for want_correct in (True, False):
        show_results(model, feats, test_label_arr, test_img_arr, show_positive=want_correct)