Furthest Nearest Neighbor 方法就是其他文章中的 Discrepancy 方法,是一种 diversity sampling 方法。
由于特征空间是不断变化的,在特征空间上使用 Discrepancy 方法违背了该准则的初衷。
import os
import torch
import numpy as np
from copy import deepcopy
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import StratifiedKFold
class FNN_2DLDA(object):
    """Furthest-nearest-neighbour sample selection in a 2D-LDA feature space.

    The labelled samples are used to fit a two-sided linear projection
    (``Wr`` on the left, ``Wl`` on the right) from within-class and
    between-class scatter matrices.  Every image ``X[i]`` is mapped to the
    ``K x K`` feature ``Wr @ X[i] @ Wl``.  Selection then repeatedly takes the
    labelled sample whose nearest unlabelled neighbour is furthest away and
    moves that neighbour to the labelled set.

    Attributes:
        X: training images, float32 tensor of shape (nSample, nRow, nCol).
        y: training labels, NumPy array of length nSample.
        labeled / unlabeled: lists of sample indices into ``X``.
        class_mean, global_mean, class_count: statistics of the labelled set.
        S_bl, S_br, S_wl, S_wr: between/within scatter matrices; the ``l``
            variants are (nCol, nCol), the ``r`` variants are (nRow, nRow).
        Wl, Wr: projections of shape (nCol, K) and (K, nRow).
        X_feature: projected features of shape (nSample, K, K).
    """

    def __init__(self, X_train, y_train, labeled, budget, X_test, y_test,
                 n_components=10, batch_size=5):
        """Fit the initial projection from the initially labelled samples.

        Args:
            X_train: images of shape (nSample, nRow, nCol).
            y_train: one label per training image.
            labeled: indices of the initially labelled samples.
            budget: total number of samples that may still be labelled.
            X_test, y_test: stored on the instance; not used by this class.
            n_components: requested size K of the projected feature; it is
                capped at min(nRow, nCol) so small images are supported.
            batch_size: number of samples picked per round in ``start``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        # All internal buffers are float32, so the images are converted once.
        self.X = torch.as_tensor(X_train, dtype=torch.float32)
        self.y = np.asarray(y_train)
        self.X_test = X_test
        self.y_test = y_test
        self.nSample = self.X.shape[0]
        print("样本个数=",self.nSample)
        # Plain ints keep the index lists hashable and printable.
        self.labeled = [int(i) for i in labeled]
        self.unlabeled = self.init_unlabeled_index()
        self.labels = np.sort(np.unique(self.y))  # sorted list of class labels
        self.nClass = len(self.labels)
        self.budget = deepcopy(budget)
        self.nRow, self.nCol = self.X[0].shape  # image height and width
        self.K = min(int(n_components), self.nRow, self.nCol)
        self.batch_size = batch_size
        self.class_mean, self.global_mean, self.class_count = self.get_init_mean()
        self.S_bl, self.S_br, self.S_wl, self.S_wr = self.get_init_Sbl_Sbr_Swl_Swr()
        self.Wl, self.Wr = self.get_Wl_Wr()
        self.X_feature = self.get_feature()

    def init_unlabeled_index(self):
        """Return every sample index that is not in ``self.labeled``."""
        taken = set(self.labeled)
        return [i for i in range(self.nSample) if i not in taken]

    def _class_members(self, i):
        """Return the labelled indices whose label is ``self.labels[i]``."""
        return [idx for idx in self.labeled if self.y[idx] == self.labels[i]]

    def _weighted_global_mean(self, class_mean, class_count):
        """Return the class means averaged with their sample counts as weights."""
        total = class_count.sum()
        if total == 0:
            return torch.zeros((self.nRow, self.nCol))
        weights = (class_count / total).view(-1, 1, 1)
        return (weights * class_mean).sum(dim=0)

    def get_init_mean(self):
        """Compute per-class means, the global mean and per-class counts.

        Returns:
            (class_mean, global_mean, class_count) computed over
            ``self.labeled``.  A class without labelled samples keeps a zero
            mean and a zero count instead of producing NaN.
        """
        class_mean = torch.zeros((self.nClass, self.nRow, self.nCol))
        class_count = torch.zeros(self.nClass)
        for i in range(self.nClass):
            ids = self._class_members(i)
            class_count[i] = len(ids)
            if ids:
                class_mean[i] = torch.mean(self.X[ids], dim=0)
        global_mean = self._weighted_global_mean(class_mean, class_count)
        return class_mean, global_mean, class_count

    def get_init_Sbl_Sbr_Swl_Swr(self):
        """Compute the scatter matrices from the current labelled statistics.

        Reads ``self.class_mean``, ``self.global_mean``, ``self.class_count``
        and ``self.labeled``.

        Returns:
            (S_bl, S_br, S_wl, S_wr): between-class and within-class scatter,
            each as a column-side (nCol, nCol) and row-side (nRow, nRow) pair.
        """
        # Between-class scatter: count-weighted deviation of each class mean.
        S_bl = torch.zeros((self.nCol, self.nCol))
        S_br = torch.zeros((self.nRow, self.nRow))
        for i in range(self.nClass):
            dev = self.class_mean[i] - self.global_mean
            S_bl += self.class_count[i] * torch.mm(dev.T, dev)
            S_br += self.class_count[i] * torch.mm(dev, dev.T)
        # Within-class scatter: deviation of each sample from its class mean.
        S_wl = torch.zeros((self.nCol, self.nCol))
        S_wr = torch.zeros((self.nRow, self.nRow))
        for i in range(self.nClass):
            for idx in self._class_members(i):
                dev = self.X[idx] - self.class_mean[i]
                S_wl += torch.mm(dev.T, dev)
                S_wr += torch.mm(dev, dev.T)
        return S_bl, S_br, S_wl, S_wr

    def _top_eigvecs(self, S_w, S_b):
        """Return the K leading eigenvectors of pinv(S_w) @ S_b as columns."""
        eig_val, eig_vec = torch.linalg.eig(torch.mm(torch.linalg.pinv(S_w), S_b))
        # torch.linalg.eig returns complex tensors; rank by the real part.
        order = torch.argsort(eig_val.real, descending=True)[: self.K]
        # Eigenvectors are the COLUMNS of eig_vec, not its rows.
        return eig_vec[:, order].real

    def get_Wl_Wr(self):
        """Build the right projection Wl (nCol, K) and left projection Wr (K, nRow)."""
        Wl = self._top_eigvecs(self.S_wl, self.S_bl)
        Wr = self._top_eigvecs(self.S_wr, self.S_br).T
        return Wl, Wr

    def get_feature(self):
        """Project every image: returns a (nSample, K, K) tensor of Wr @ X[i] @ Wl."""
        return torch.matmul(self.Wr, torch.matmul(self.X, self.Wl))

    def incremental_update_X_feature(self, selected):
        """Refresh statistics, projections and features after labelling.

        Args:
            selected: indices that were just labelled.  They are expected to
                be in ``self.labeled`` already (as ``image_select`` leaves
                them) and not yet counted in ``self.class_count``.
        """
        # Class means are updated incrementally from the sum of new samples.
        for i in range(self.nClass):
            new_ids = [idx for idx in selected if self.y[idx] == self.labels[i]]
            if not new_ids:
                continue
            new_sum = self.X[new_ids].sum(dim=0)
            old_n = float(self.class_count[i])  # copy: indexing yields a view
            new_n = old_n + len(new_ids)
            self.class_mean[i] = (old_n * self.class_mean[i] + new_sum) / new_n
            self.class_count[i] = new_n
        self.global_mean = self._weighted_global_mean(self.class_mean, self.class_count)
        # Scatter matrices, projections and features are rebuilt from scratch.
        self.S_bl, self.S_br, self.S_wl, self.S_wr = self.get_init_Sbl_Sbr_Swl_Swr()
        self.Wl, self.Wr = self.get_Wl_Wr()
        self.X_feature = self.get_feature()

    def _nearest_unlabeled(self, idx):
        """Return (distance, index) of the unlabelled sample closest to ``idx``.

        Distance is the Frobenius norm between projected features.  The
        unlabelled pool must not be empty.
        """
        pool = self.unlabeled
        diffs = self.X_feature[pool] - self.X_feature[idx]
        dists = diffs.flatten(start_dim=1).norm(dim=1)
        pos = int(torch.argmin(dists))
        return float(dists[pos]), pool[pos]

    def image_select(self, batch_size):
        """Pick up to ``batch_size`` unlabelled samples and label them.

        For every labelled sample the nearest unlabelled neighbour is tracked.
        Each step takes the labelled sample whose nearest neighbour is
        furthest away and moves that neighbour from ``self.unlabeled`` to
        ``self.labeled``.

        Args:
            batch_size: number of samples to pick; fewer are returned when
                the unlabelled pool runs out.

        Returns:
            List of the picked sample indices, in selection order.

        Raises:
            ValueError: if samples are requested but nothing is labelled yet.
        """
        n_pick = min(batch_size, len(self.unlabeled))
        selected = []
        if n_pick <= 0:
            return selected
        if not self.labeled:
            raise ValueError("image_select needs at least one labeled sample")
        # labelled index -> (distance, index) of its nearest unlabelled sample
        nearest = OrderedDict()
        for idx in self.labeled:
            nearest[idx] = self._nearest_unlabeled(idx)
        for _ in range(n_pick):
            anchor = max(nearest, key=lambda k: nearest[k][0])
            picked = nearest[anchor][1]
            selected.append(picked)
            self.labeled.append(picked)
            self.unlabeled.remove(picked)
            if not self.unlabeled:
                break
            # Any entry that pointed at the picked sample is stale now, not
            # only the anchor's; otherwise the same sample could win twice.
            for idx, (_, neighbour) in list(nearest.items()):
                if neighbour == picked:
                    nearest[idx] = self._nearest_unlabeled(idx)
            nearest[picked] = self._nearest_unlabeled(picked)
        return selected

    def start(self):
        """Spend ``self.budget`` in rounds of at most ``self.batch_size`` picks.

        The projection and features are refreshed after every round except
        the last one, where no further selection would use them.
        """
        while self.budget > 0:
            request = min(self.batch_size, self.budget)
            selected = self.image_select(batch_size=request)
            self.budget -= request
            print("selected::",selected)
            if not selected:
                # The unlabelled pool is exhausted; the rest cannot be spent.
                break
            # The model is only updated while labelling budget remains.
            if self.budget > 0:
                self.incremental_update_X_feature(selected=selected)
if __name__ == '__main__':
    # Demo: load the face images, hold out one stratified fold, seed two
    # labelled samples per class and run the selection for the whole budget.
    path_dir = r"E:\PycharmProjects\DataSets\FaceData\yalefaces"
    # =============== basic information =================
    nSample = 165
    # Consecutive images are grouped in runs of 11 per label
    # (presumably 11 images per subject -- confirm against the dataset).
    nPerLabel = 11
    nRow = 243
    nCol = 320
    Budget = 30
    # ============== build the labels ==================
    # Sample i receives label i // 11 + 1, stored as floats as before.
    y = (np.arange(nSample) // nPerLabel + 1).astype(float)
    # ============ read the image data =================
    # os.listdir returns names in arbitrary order; sorting keeps the
    # position-based labels above aligned with the files.
    names = sorted(
        name for name in os.listdir(path_dir)
        if name.split(".")[0][:7] == "subject"
    )
    if len(names) != nSample:
        raise ValueError(
            "expected %d image files in %s, found %d" % (nSample, path_dir, len(names))
        )
    X = torch.zeros((nSample, nRow, nCol))
    for index, name in enumerate(names):
        # The context manager closes the underlying file handle.
        with Image.open(os.path.join(path_dir, name)) as img:
            X[index] = torch.from_numpy(np.array(img))
    SKF = StratifiedKFold(n_splits=5, shuffle=True)
    for train_idx, test_idx in SKF.split(X=X, y=y):
        X_train = X[train_idx]
        y_train = y[train_idx]
        X_test = X[test_idx]
        y_test = y[test_idx]
        # Group the training indices by label.
        label_dict = OrderedDict((lab, []) for lab in np.unique(y_train))
        for idx, lab in enumerate(y_train):
            label_dict[lab].append(idx)
        # Two randomly chosen samples per class form the initial labelled set.
        labeled = []
        for idxlist in label_dict.values():
            labeled.extend(np.random.choice(idxlist, size=2, replace=False))
        model = FNN_2DLDA(X_train=X_train, y_train=y_train, labeled=labeled, budget=Budget, X_test=X_test, y_test=y_test)
        model.start()
        # Only the first fold is run.
        break