论文复现：Active Learning with the Furthest Nearest Neighbor Criterion for Facial Age Estimation

Furthest Nearest Neighbor（FNN）方法就是其他文章中的 Discrepancy 方法，是一种 diversity sampling 方法。

 

由于特征空间（2DLDA 投影）会随着标注样本的增加而被重新学习、不断变化，在该特征空间上使用 Discrepancy 方法违背了该准则的初衷。

import os
import torch
import numpy as np
from copy import deepcopy
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import StratifiedKFold


class FNN_2DLDA(object):
    """Furthest-nearest-neighbour (FNN) active learning on 2DLDA image features.

    Every image is projected to a K x K feature matrix with two-sided 2DLDA,
    ``feature = Wr @ image @ Wl``, where ``Wl`` / ``Wr`` are learned from the
    currently labeled samples.  Samples are queried greedily in feature space
    (see ``image_select``), and the projection is refitted after every batch
    while budget remains.

    Args:
        X_train: image tensor of shape (nSample, nRow, nCol).
        y_train: 1-D array of class labels, one per training image.
        labeled: indices into ``X_train`` that are labeled at the start.
        budget: total number of samples to query.
        X_test, y_test: held-out data; stored but not used by the selection.
        K: number of 2DLDA projection directions on each side
            (must not exceed ``min(nRow, nCol)``).
        batch_size: number of samples queried between two refits (>= 1).

    Raises:
        ValueError: if ``K`` exceeds the image height/width or ``batch_size < 1``.
    """

    def __init__(self, X_train, y_train, labeled, budget, X_test, y_test, K=10, batch_size=5):
        self.X = X_train
        self.y = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.nSample = X_train.shape[0]
        print("样本个数=",self.nSample)
        self.labeled = [int(i) for i in labeled]     # indices of labeled samples (own copy)
        self.unlabeled = self.init_unlabeled_index()
        self.labels = np.sort(np.unique(y_train))    # sorted list of distinct class labels
        self.nClass = len(self.labels)
        self.budget = deepcopy(budget)
        self.nRow, self.nCol = self.X[0].shape       # image height and width
        if K > min(self.nRow, self.nCol):
            raise ValueError(f"K={K} exceeds image size {self.nRow}x{self.nCol}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.K = K
        self.batch_size = batch_size
        self.class_mean, self.global_mean, self.class_count = self.get_init_mean()
        self.S_bl, self.S_br, self.S_wl, self.S_wr = self.get_init_Sbl_Sbr_Swl_Swr()
        self.Wl, self.Wr = self.get_Wl_Wr()
        self.X_feature = self.get_feature()

    def init_unlabeled_index(self):
        """Return, in ascending order, every sample index not in ``self.labeled``."""
        labeled = set(self.labeled)   # set membership keeps this O(nSample)
        return [i for i in range(self.nSample) if i not in labeled]

    def _class_index(self, idx):
        """Return the position of sample ``idx``'s label within ``self.labels``."""
        # self.labels is sorted and contains every label of y, so this is exact.
        return int(np.searchsorted(self.labels, self.y[idx]))

    @staticmethod
    def _global_mean(class_mean, class_count):
        """Count-weighted average of the class means (= mean of all labeled samples)."""
        total = class_count.sum()
        if total == 0:
            return class_mean.new_zeros(class_mean.shape[1:])
        return (class_count[:, None, None] * class_mean).sum(dim=0) / total

    def get_init_mean(self):
        """Compute per-class means, the global mean and per-class counts.

        Only labeled samples are used.  A class with no labeled sample keeps a
        zero mean and a zero count, so it contributes nothing to the scatters.

        Returns:
            ``(class_mean, global_mean, class_count)`` with shapes
            (nClass, nRow, nCol), (nRow, nCol) and (nClass,).
        """
        class_mean = torch.zeros((self.nClass, self.nRow, self.nCol))
        class_count = torch.zeros(self.nClass)
        for i, lab in enumerate(self.labels):
            ids = [idx for idx in self.labeled if self.y[idx] == lab]
            class_count[i] = len(ids)
            if ids:
                class_mean[i] = torch.mean(self.X[ids], dim=0)
        return class_mean, self._global_mean(class_mean, class_count), class_count

    def get_init_Sbl_Sbr_Swl_Swr(self):
        """Build the 2DLDA scatter matrices from the current labeled statistics.

        Reads ``class_mean``, ``global_mean``, ``class_count`` and ``labeled``.

        Returns:
            ``(S_bl, S_br, S_wl, S_wr)``: between-class and within-class
            scatters, of shape (nCol, nCol) for the "l" (column) side and
            (nRow, nRow) for the "r" (row) side.
        """
        S_bl = torch.zeros((self.nCol, self.nCol))
        S_br = torch.zeros((self.nRow, self.nRow))
        for i in range(self.nClass):
            diff = self.class_mean[i] - self.global_mean
            S_bl += self.class_count[i] * (diff.T @ diff)
            S_br += self.class_count[i] * (diff @ diff.T)

        S_wl = torch.zeros((self.nCol, self.nCol))
        S_wr = torch.zeros((self.nRow, self.nRow))
        for idx in self.labeled:
            diff = self.X[idx] - self.class_mean[self._class_index(idx)]
            S_wl += diff.T @ diff
            S_wr += diff @ diff.T
        return S_bl, S_br, S_wl, S_wr

    def _leading_eigenvectors(self, S_w, S_b):
        """Return the K eigenvectors of pinv(S_w) @ S_b with the largest eigenvalues.

        ``torch.linalg.eig`` returns eigenvectors as COLUMNS and complex values
        for this non-symmetric matrix; eigenvalues are ranked by real part and
        the real part of the eigenvectors is kept.  Result shape: (dim, K).
        """
        eig_val, eig_vec = torch.linalg.eig(torch.linalg.pinv(S_w) @ S_b)
        order = torch.argsort(eig_val.real, descending=True)[: self.K]
        return eig_vec[:, order].real

    def get_Wl_Wr(self):
        """Return the projections ``Wl`` (nCol, K) and ``Wr`` (K, nRow)."""
        Wl = self._leading_eigenvectors(self.S_wl, self.S_bl)
        Wr = self._leading_eigenvectors(self.S_wr, self.S_br).T   # eigenvectors as rows
        return Wl, Wr

    def get_feature(self):
        """Project every sample: returns ``Wr @ X[i] @ Wl`` stacked to (nSample, K, K)."""
        # matmul broadcasts the (K, nRow) / (nCol, K) projections over the batch axis.
        return torch.matmul(torch.matmul(self.Wr, self.X), self.Wl)

    def incremental_update_X_feature(self, selected):
        """Refit 2DLDA after ``selected`` were labeled and re-project all samples.

        ``selected`` must already be in ``self.labeled`` (``image_select``
        does this): class means/counts are updated incrementally from
        ``selected``, while the within-class scatter is rebuilt from
        ``self.labeled``.  Updates ``class_mean``, ``class_count``,
        ``global_mean``, the four scatters, ``Wl``/``Wr`` and ``X_feature``.
        """
        for i, lab in enumerate(self.labels):
            new_ids = [idx for idx in selected if self.y[idx] == lab]
            if not new_ids:
                continue
            n_old = self.class_count[i]
            n_new = len(new_ids)
            # Running mean: (n_old * old_mean + sum of new samples) / (n_old + n_new).
            self.class_mean[i] = (n_old * self.class_mean[i] + self.X[new_ids].sum(dim=0)) / (n_old + n_new)
            self.class_count[i] = n_old + n_new
        self.global_mean = self._global_mean(self.class_mean, self.class_count)
        self.S_bl, self.S_br, self.S_wl, self.S_wr = self.get_init_Sbl_Sbr_Swl_Swr()
        self.Wl, self.Wr = self.get_Wl_Wr()
        self.X_feature = self.get_feature()

    def image_select(self, batch_size):
        """Greedily query up to ``batch_size`` unlabeled samples.

        For every labeled sample the nearest unlabeled sample (Frobenius
        distance between feature matrices) is tracked; each step queries the
        unlabeled member of the pair with the largest such distance.  After a
        query, every pair that pointed at the queried sample is recomputed and
        the new labeled sample gets its own pair.

        NOTE(review): FNN is often stated from the unlabeled side (query the
        unlabeled sample farthest from its nearest labeled one); this keeps the
        labeled-side formulation of the original code — confirm against the paper.

        Moves the queried indices from ``self.unlabeled`` to ``self.labeled``.

        Returns:
            List of queried indices (shorter than ``batch_size`` if the
            unlabeled pool or the labeled set is empty).
        """
        selected = []
        if batch_size <= 0 or not self.unlabeled or not self.labeled:
            return selected
        feats = self.X_feature.reshape(self.nSample, -1)

        def nearest_unlabeled(idx):
            # Ties resolve to the first unlabeled index (argmin returns the first minimum).
            cand = torch.as_tensor(self.unlabeled)
            dist = torch.linalg.vector_norm(feats[cand] - feats[idx], dim=1)
            pos = int(torch.argmin(dist))
            return self.unlabeled[pos], float(dist[pos])

        # labeled index -> (nearest unlabeled index, distance)
        nearest = {idx: nearest_unlabeled(idx) for idx in self.labeled}
        for _ in range(batch_size):
            anchor = max(nearest, key=lambda k: nearest[k][1])
            pick = nearest[anchor][0]
            selected.append(pick)
            self.labeled.append(pick)
            self.unlabeled.remove(pick)
            if not self.unlabeled:
                break
            # Pairs pointing at the now-labeled sample are stale; refresh them.
            stale = [k for k, (j, _) in nearest.items() if j == pick]
            for k in stale + [pick]:
                nearest[k] = nearest_unlabeled(k)
        return selected

    def start(self):
        """Spend the whole budget in batches, refitting 2DLDA between batches."""
        while self.budget > 0:
            n_query = min(self.batch_size, self.budget)
            selected = self.image_select(batch_size=n_query)
            self.budget -= n_query
            print("selected::",selected)
            # Refit only if another batch will be queried.
            if self.budget > 0:
                self.incremental_update_X_feature(selected=selected)







if __name__ == '__main__':
    path_dir = r"E:\PycharmProjects\DataSets\FaceData\yalefaces"
    # =============== Basic settings =================
    # The Yale face images (files named "subjectNN.<condition>") are 243 x 320
    # pixels; 165 images of 11 per subject give 15 classes.
    nRow = 243
    nCol = 320
    Budget = 30
    n_init_per_class = 2   # randomly labeled images per class in each training fold

    # ============ Read images; the class label is the subject number =================
    # The label is parsed from the file name ("subject07.glasses" -> 7) instead
    # of being inferred from the position in os.listdir(), whose order is not
    # guaranteed; names are sorted for a reproducible sample order.
    images = []
    y = []
    for name in sorted(os.listdir(path_dir)):
        stem = name.split(".")[0]
        if stem[:7] != "subject":
            continue
        # convert("L") yields gray levels even if a file is palette-encoded.
        with Image.open(os.path.join(path_dir, name)) as img:
            images.append(torch.from_numpy(np.array(img.convert("L"), dtype=np.float32)))
        y.append(float(int(stem[7:])))
    X = torch.stack(images)
    y = np.asarray(y)
    if tuple(X.shape[1:]) != (nRow, nCol):
        raise ValueError(f"expected {nRow}x{nCol} images, got {tuple(X.shape[1:])}")

    SKF = StratifiedKFold(n_splits=5, shuffle=True)
    for train_idx, test_idx in SKF.split(X=X, y=y):
        X_train, y_train = X[train_idx], y[train_idx]
        X_test, y_test = X[test_idx], y[test_idx]
        # Seed the learner with n_init_per_class random images of every class.
        labeled = []
        for lab in np.unique(y_train):
            class_ids = np.flatnonzero(y_train == lab)
            labeled.extend(np.random.choice(class_ids, size=n_init_per_class, replace=False).tolist())

        model = FNN_2DLDA(X_train=X_train, y_train=y_train, labeled=labeled, budget=Budget, X_test=X_test, y_test=y_test)
        model.start()
        break  # only the first fold is run

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

DeniuHe

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值