How to Make a Deep Learning Model Forget a Class

Seeing this title, you may wonder: isn't a model better the more classes it can recognize? Why would we deliberately make it forget classes it worked so hard to learn? Class forgetting (machine unlearning) is still an emerging field, driven mainly by regulatory requirements and user-privacy protection. As data-protection laws tighten, such as the EU's GDPR (General Data Protection Regulation), users have the right to request deletion of their personal data, and that includes having their data removed from the datasets used to train data-driven models.

To meet this requirement, the traditional approach is to retrain the model from scratch, or fine-tune it repeatedly, for every deletion request. This is time-consuming and computationally expensive, and fine-tuning alone often achieves little: in a deep learning setting the model has usually learned complex representations from large amounts of data, entangling the features of many classes, so simply deleting the raw data of one class is not enough, and the model still retains traces of that class.
To address this need, this article introduces a different idea: an SVD-based, gradient-free class-forgetting algorithm that tackles the limitations of traditional methods in both computational cost and the number of samples required.

Paper: Deep Unlearning: Fast and Efficient Gradient-free Class Forgetting | OpenReview

The method in the paper is not particularly complicated; the main steps are as follows:
  • First, use SVD to estimate a retain space and a forget space, corresponding to the activations of the classes to be kept and the class to be forgotten, respectively;

  • Compute and remove the information shared between the two spaces, isolating the class-discriminative feature space;

  • Then update the model weights so that activations of forget-class samples are suppressed as much as possible, yielding a model that has unlearned that class (a conceptual sketch of this projection step follows the list);

  • Several analyses, including membership inference attacks, saliency-based feature analysis, and confusion matrix analysis, show that the resulting model behaves consistently with a model retrained without the forgotten class.
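
The projection step in the third bullet can be sketched roughly as follows. This is a minimal conceptual reading of the paper in PyTorch, not the repository's exact implementation; the variance threshold and the helper name forget_projection are my own:

import torch

def forget_projection(retain_act, forget_act, var_keep=0.95):
    """Conceptual sketch: build a matrix that projects activations away from the
    forget-class-discriminative directions (my reading of the paper, not the
    repository's exact formula). retain_act / forget_act: (num_samples, dim)."""
    # Orthonormal bases of the retain and forget activation spaces via SVD.
    Ur, Sr, _ = torch.linalg.svd(retain_act.T, full_matrices=False)
    Uf, Sf, _ = torch.linalg.svd(forget_act.T, full_matrices=False)

    def top_directions(U, S, keep=var_keep):
        # Keep the leading directions that explain roughly `keep` of the variance.
        var = S**2 / (S**2).sum()
        k = int((var.cumsum(0) < keep).sum().item()) + 1
        return U[:, :k]

    Ur, Uf = top_directions(Ur, Sr), top_directions(Uf, Sf)

    # Remove from the forget space whatever it shares with the retain space,
    # leaving only the forget-discriminative directions.
    Uf_dis = Uf - Ur @ (Ur.T @ Uf)

    # Projector onto the orthogonal complement of those directions: applying it
    # to the weights suppresses forget-class activations while preserving the rest.
    d = retain_act.shape[1]
    return torch.eye(d) - Uf_dis @ torch.linalg.pinv(Uf_dis)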

The full algorithm is summarized in the paper's pseudocode:

[Figure: algorithm pseudocode from the paper]

The paper also comes with example code; repository address:

https://github.com/sangamesh-kodge/class_forgetting

demo.py implements the complete pipeline:

1. Data generation and preparation:

# Imports used by the demo.py snippets below.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from copy import deepcopy
from tqdm import tqdm

def get_dataset(train_samples_per_class, test_samples_per_class, std = 0.5):
    print("-"*40)
    print("Generating Dataset")
    print("-"*40)
    means = [(1.0,1.0),(-1.0,1.0), (-1.0,-1.0), (1.0,-1.0) ]
    std_dev = (std,std)
    data_list = []
    target_list = []
    for classes in range(4):
        for sample in tqdm(range(train_samples_per_class), desc=f"Train class {classes}"):
            data_list.append(torch.normal(mean= torch.tensor(means[classes]), std= torch.tensor(std_dev) ))
            target_list.append(torch.tensor(classes).long())
    data_tensor = torch.stack(data_list, 0)
    target_tensor = torch.stack(target_list, 0)
    train_data = (data_tensor, target_tensor)


    data_list = []
    target_list = []
    for classes in range(4):
        for sample in tqdm(range(test_samples_per_class), desc=f"Test class {classes}"):
            data_list.append(torch.normal(mean= torch.tensor(means[classes]), std= torch.tensor(std_dev) ))
            target_list.append(torch.tensor(classes).long())
    data_tensor = torch.stack(data_list, 0)
    target_tensor = torch.stack(target_list, 0)
    test_data = (data_tensor, target_tensor)
    return train_data, test_data
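
For completeness, the returned (data, target) tensor pairs can be wrapped into ordinary PyTorch loaders. This is a minimal usage sketch; the sample counts and batch size here are arbitrary placeholders, not demo.py's settings:

from torch.utils.data import TensorDataset, DataLoader

train_data, test_data = get_dataset(train_samples_per_class=1000, test_samples_per_class=200)
train_loader = DataLoader(TensorDataset(*train_data), batch_size=64, shuffle=True)
test_loader = DataLoader(TensorDataset(*test_data), batch_size=64)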

2. Network model definition:

class Linear5(nn.Module):
    def __init__(self, in_feature=2, hidden_features=5, num_classes=4):
        super(Linear5, self).__init__()
        self.fc1 = nn.Linear(in_features=in_feature, out_features=hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.bn2 = nn.BatchNorm1d(hidden_features)
        self.fc3 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.bn3 = nn.BatchNorm1d(hidden_features)
        self.fc4 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.bn4 = nn.BatchNorm1d(hidden_features)
        self.fc5 = nn.Linear(in_features=hidden_features, out_features=num_classes)


    def forward(self,x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = F.relu(self.bn3(self.fc3(x)))
        x = F.relu(self.bn4(self.fc4(x)))
        return F.log_softmax(self.fc5(x), 1)
    
    def get_activations(self, x):
        act = {"pre":OrderedDict(), "post":OrderedDict()}
        act["pre"]["fc1"] = deepcopy(x.clone().detach().cpu().numpy())
        x = self.fc1(x)
        act["post"]["fc1"] = deepcopy(x.clone().detach().cpu().numpy())
        x = F.relu(self.bn1(x))
        act["pre"]["fc2"] = deepcopy(x.clone().detach().cpu().numpy())
        x = self.fc2(x)
        act["post"]["fc2"] = deepcopy(x.clone().detach().cpu().numpy())
        x = F.relu(self.bn2(x))
        act["pre"]["fc3"] = deepcopy(x.clone().detach().cpu().numpy())
        x = self.fc3(x)
        act["post"]["fc3"] = deepcopy(x.clone().detach().cpu().numpy())
        x = F.relu(self.bn3(x))
        act["pre"]["fc4"] = deepcopy(x.clone().detach().cpu().numpy())
        x = self.fc4(x)
        act["post"]["fc4"] = deepcopy(x.clone().detach().cpu().numpy())
        x = F.relu(self.bn4(x))
        act["pre"]["fc5"] = deepcopy(x.clone().detach().cpu().numpy())
        x = self.fc5(x)
        act["post"]["fc5"] = deepcopy(x.clone().detach().cpu().numpy())
        return act 
    
    def project_weights(self, projection_mat_dict):
        self.fc1.weight.data = torch.mm(projection_mat_dict["post"]["fc1"].transpose(0,1), torch.mm(self.fc1.weight.data, projection_mat_dict["pre"]["fc1"].transpose(0,1)))
        self.fc1.bias.data = torch.mm(self.fc1.bias.data.unsqueeze(0), projection_mat_dict["post"]["fc1"]).squeeze(0)


        self.fc2.weight.data = torch.mm(projection_mat_dict["post"]["fc2"].transpose(0,1), torch.mm(self.fc2.weight.data, projection_mat_dict["pre"]["fc2"].transpose(0,1)))
        self.fc2.bias.data = torch.mm(self.fc2.bias.data.unsqueeze(0), projection_mat_dict["post"]["fc2"]).squeeze(0)


        self.fc3.weight.data = torch.mm(projection_mat_dict["post"]["fc3"].transpose(0,1), torch.mm(self.fc3.weight.data, projection_mat_dict["pre"]["fc3"].transpose(0,1)))
        self.fc3.bias.data = torch.mm(self.fc3.bias.data.unsqueeze(0), projection_mat_dict["post"]["fc3"]).squeeze(0)


        self.fc4.weight.data = torch.mm(projection_mat_dict["post"]["fc4"].transpose(0,1), torch.mm(self.fc4.weight.data, projection_mat_dict["pre"]["fc4"].transpose(0,1)))
        self.fc4.bias.data = torch.mm(self.fc4.bias.data.unsqueeze(0), projection_mat_dict["post"]["fc4"]).squeeze(0)


        self.fc5.weight.data =  torch.mm(self.fc5.weight.data, projection_mat_dict["pre"]["fc5"].transpose(0,1))
        return
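
To make the expected shapes and dictionary layout concrete, here is a small illustrative example that collects activations and applies identity projection matrices, which leave the weights unchanged. It reuses the imports and the Linear5 class above; the variable names are my own:

model = Linear5().eval()
x = torch.randn(16, 2)           # a batch of 2-D inputs
act = model.get_activations(x)   # e.g. act["pre"]["fc1"] has shape (16, 2)

# Identity projections: project_weights leaves every layer unchanged.
proj = {"pre": OrderedDict(), "post": OrderedDict()}
for name, layer in [("fc1", model.fc1), ("fc2", model.fc2), ("fc3", model.fc3),
                    ("fc4", model.fc4), ("fc5", model.fc5)]:
    proj["pre"][name] = torch.eye(layer.in_features)
    proj["post"][name] = torch.eye(layer.out_features)  # "post" of fc5 is unused
model.project_weights(proj)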

3. Model training and evaluation:

def train(model, train_loader, device, optimizer, scheduler):
    model.train()
    total_loss = 0
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        log_prob = model(data)
        loss = F.nll_loss(log_prob, target)
        total_loss+= loss.detach().item()/len(train_loader.dataset)
        loss.backward()
        optimizer.step()
        scheduler.step(loss)
    return total_loss
        


def test(model, test_loader, device):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            log_prob = model(data)
            pred = log_prob.argmax(dim=1, keepdim=True)
            loss = F.nll_loss(log_prob, target)
            total_loss += loss.item()/len(test_loader.dataset)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100*correct/len(test_loader.dataset)
    return total_loss, acc
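
A minimal way to wire these up, reusing the loaders from the dataset sketch above. The optimizer and hyperparameters are placeholders rather than demo.py's actual settings; the scheduler.step(loss) call inside train() suggests a loss-driven scheduler such as ReduceLROnPlateau:

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Linear5().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

for epoch in range(20):
    train(model, train_loader, device, optimizer, scheduler)
    loss, acc = test(model, test_loader, device)
    print(f"epoch {epoch}: test loss {loss:.4f}, acc {acc:.2f}%")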

4. Computing the retain and forget spaces with SVD:

    # retain_act and the two retain_* dictionaries are built earlier in demo.py,
    # analogously to the forget-side code below (activations of retain-class samples
    # collected via model.get_activations, dicts of "pre"/"post" OrderedDicts).
    # Compute Ur
    for loc in retain_act.keys():
        for key in retain_act[loc].keys():
            activation = torch.Tensor(retain_act[loc][key]).to("cuda").transpose(0,1)
            U,S,Vh = torch.linalg.svd(activation, full_matrices=False)
            U = U.cpu().numpy()
            S = S.cpu().numpy()
            retain_mat_dict[loc][key] = U
            retain_normalize_var_mat_dict[loc][key] = S**2 / (S**2).sum()


    # data1 is prepared earlier in demo.py and holds the samples of the class to be
    # forgotten (assumption based on how forget_data is used below).
    forget_data = data1[0][:train_samples_per_class]
    forget_act = model.get_activations(forget_data.to(device))
    forget_mat_dict= {"pre":OrderedDict(), "post":OrderedDict()}
    forget_normalize_var_mat_dict = {"pre":OrderedDict(), "post":OrderedDict()}
    # Compute Uf
    for loc in forget_act.keys():
        for key in forget_act[loc].keys():
            activation = torch.Tensor(forget_act[loc][key]).to("cuda").transpose(0,1)
            U,S,Vh = torch.linalg.svd(activation, full_matrices=False)
            U = U.cpu().numpy()
            S = S.cpu().numpy()  
            forget_mat_dict[loc][key] = U
            forget_normalize_var_mat_dict[loc][key] = S**2 / (S**2).sum()
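
demo.py then combines these bases into per-layer projection matrices and applies them with model.project_weights. The repository follows the paper's exact weighting (scaling the basis vectors by their normalized variances and a scaling hyperparameter); the hedged sketch below only illustrates the shape of that step, reusing the conceptual forget_projection helper from earlier rather than the repo's formula:

projection_mat_dict = {"pre": OrderedDict(), "post": OrderedDict()}
for loc in retain_act.keys():
    for key in retain_act[loc].keys():
        # Build a suppression projection for this layer's input/output space
        # from the retain and forget activations collected above.
        P = forget_projection(torch.Tensor(retain_act[loc][key]),
                              torch.Tensor(forget_act[loc][key]))
        projection_mat_dict[loc][key] = P.to(device)

model.project_weights(projection_mat_dict)

# After projection, the forget class should be misclassified while
# retain-class accuracy stays close to the original model's.
_, acc = test(model, test_loader, device)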

main.py also provides several other unlearning algorithms for comparison, such as SSD, OUR, TARUN, SCRUB, GOEL, and SALUN.

SALUN:

def salun_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, optimizer, epochs, **kwargs):
    mask = get_salun_mask(args, model, device, forget_loader)     
    model.train()
    train_loss= 0
    train_forget_acc =100.0
    steps = 0
    # Random relabeling
    forget_dataset = copy.deepcopy(forget_loader.dataset)
    forget_dataset.update_labels(np.random.randint(0, args.num_classes, len(forget_dataset) ))
    forget_dataset, _ = torch.utils.data.random_split(forget_dataset, [args.num_forget_samples, len(forget_dataset) - args.num_forget_samples])
    
    #Subsample retain set. 
    retain_dataset = retain_loader.dataset
    retain_dataset, _ = torch.utils.data.random_split(retain_dataset, [args.num_retain_samples, len(retain_dataset) - args.num_retain_samples])


    train_dataset = torch.utils.data.ConcatDataset([forget_dataset,retain_dataset])


    salun_train_loader= torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True)


    max_steps = epochs*(len(train_dataset)//args.batch_size)


    for epoch in range(epochs):
        for data, target in salun_train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            optimizer.zero_grad()
            loss.backward()
            steps+=1
            if mask:
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        param.grad *= mask[name]
            optimizer.step()
            train_loss += loss.detach().item()


 
            if args.dry_run:
                break        
        if args.dry_run:
            break    
    return model
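
get_salun_mask is defined elsewhere in the repository. In the SalUn method, the mask keeps only the parameters whose gradients on the forget set are largest, so that the random-relabeling fine-tuning above only updates weights salient to the forgotten class. A hedged sketch of that idea (the function name, keep_ratio, and thresholding details are illustrative, not the repo's exact code):

def saliency_mask(model, device, forget_loader, keep_ratio=0.5):
    """Hedged sketch of a get_salun_mask-style function: 1 for parameters whose
    accumulated forget-set gradient magnitude is above a global quantile
    threshold, 0 otherwise."""
    grads = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for data, target in forget_loader:
        data, target = data.to(device), target.to(device)
        model.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                grads[n] += p.grad.abs()
    # Global threshold at the (1 - keep_ratio) quantile of all gradient magnitudes.
    all_grads = torch.cat([g.flatten() for g in grads.values()])
    threshold = torch.quantile(all_grads, 1.0 - keep_ratio)
    return {n: (g >= threshold).float() for n, g in grads.items()}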

GOEL:

def goel_last_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, optimizer, epochs, **kwargs):
    train_loss= 0
    train_forget_acc =100.0
    steps = 0    
    #Subsample retain set. 
    retain_dataset = retain_loader.dataset
    retain_dataset, _ = torch.utils.data.random_split(retain_dataset, [args.num_retain_samples, len(retain_dataset) - args.num_retain_samples])
    
    goel_train_loader= torch.utils.data.DataLoader( retain_dataset, batch_size=args.batch_size, shuffle=True)
  
    model.train()
    model = resetFinalResnet(model, 1, reinit=args.goel_exact)
    
    for epoch in range(epochs):
        for data, target in goel_train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            optimizer.zero_grad()
            loss.backward()
            steps+=1
            optimizer.step()
            train_loss += loss.detach().item()
                    
            if args.dry_run:
                break        
        if args.dry_run:
            break  
    return model
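
resetFinalResnet is likewise defined elsewhere in the repository. Goel et al.'s CF-k / EU-k baselines retrain only the last k layers on the retained data, optionally reinitializing them first, which is what the reinit flag toggles. A hedged sketch of such a helper, written here for the toy Linear5 model above rather than a ResNet:

def reset_final_layers(model, k=1, reinit=True):
    """Hedged sketch of a CF-k / EU-k style reset: freeze all layers, then
    unfreeze (and optionally reinitialize) the last k Linear layers so that only
    they are retrained on the retain set. BatchNorm layers are left untouched
    for brevity."""
    layers = [model.fc1, model.fc2, model.fc3, model.fc4, model.fc5]
    for layer in layers:
        for p in layer.parameters():
            p.requires_grad = False
    for layer in layers[-k:]:
        if reinit:
            layer.reset_parameters()  # EU-k: reinitialize before retraining
        for p in layer.parameters():
            p.requires_grad = True    # CF-k when reinit=False: plain fine-tuning
    return model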

SSD:

def ssd_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, optimizer, **kwargs):
    #Subsample dataset. 
    forget_dataset = forget_loader.dataset
    forget_dataset, _ = torch.utils.data.random_split(forget_dataset, [args.num_forget_samples, len(forget_dataset) - args.num_forget_samples])
    retain_dataset = retain_loader.dataset
    retain_dataset, _ = torch.utils.data.random_split(retain_dataset, [args.num_retain_samples, len(retain_dataset) - args.num_retain_samples])
    ssd_retain_loader= torch.utils.data.DataLoader( retain_dataset, batch_size=args.batch_size , shuffle=True)
    train_dataset = torch.utils.data.ConcatDataset([forget_loader.dataset,retain_dataset])
    ssd_train_loader= torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=True)


    # Hyperparameters
    lower_bound = 1  # cap for the dampening factor: values above 1 are clipped back to 1
    exponent = 1  # exponent applied to the dampening factor (1 leaves it unchanged)
    dampening_constant =  args.ssd_lambda # Lambda from paper
    selection_weighting = args.ssd_alpha # Alpha from paper


    model.eval() # to ensure batch statistics do not change


    # Calculation of the forget set importances
    forget_importance = calc_importance(model, device, optimizer, forget_loader)
    # Calculate the importances of D (see paper); this can also be done at any point before forgetting.
    original_importance = calc_importance(model, device, optimizer,ssd_train_loader)
    # Dampen selected parameters
    with torch.no_grad():
        for (n, p), (oimp_n, oimp), (fimp_n, fimp) in zip(
            model.named_parameters(),
            original_importance.items(),
            forget_importance.items(),
        ):
            # Synapse Selection with parameter alpha
            oimp_norm = oimp.mul(selection_weighting)
            locations = torch.where(fimp > oimp_norm)


            # Synapse Dampening with parameter lambda
            weight = ((oimp.mul(dampening_constant)).div(fimp)).pow(
                exponent
            )
            update = weight[locations]
            # Clamp to 1 so the update never increases a parameter's magnitude.
            min_locs = torch.where(update > lower_bound)
            update[min_locs] = lower_bound
            p[locations] = p[locations].mul(update)  
    return model
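
calc_importance is also not shown above. In Selective Synaptic Dampening, parameter importance is the diagonal of the Fisher information, i.e. the squared gradient of the loss averaged over a dataset. A hedged sketch of such a function (the name and details are illustrative, not the repo's exact code):

def fisher_importance(model, device, loader):
    """Hedged sketch of a calc_importance-style function: accumulate the squared
    gradients of the loss over a dataset as a diagonal-Fisher importance score,
    one tensor per named parameter."""
    importance = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        model.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                importance[n] += p.grad.detach() ** 2
    return {n: imp / len(loader) for n, imp in importance.items()}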

Finally, the researchers validated the algorithm with a Vision Transformer on the ImageNet dataset. The results show that the accuracy on the forgotten class can be pushed below 1% at the cost of only about a 1.5% drop in retained-class accuracy, and the unlearned model remains resilient to membership inference attacks. Compared with baseline methods, the algorithm improves accuracy on ImageNet by 1.38% on average while requiring 10x fewer forget samples. On CIFAR-100 with a ResNet18, it also beats the best baseline by 1.8% under a stronger class-level inference attack. If you are interested, try running the demo code yourself.
