在看到这个标题的时候,大家可能会比较疑惑:模型能识别的类别难道不是越多越好吗,为什么还要让它遗忘好不容易学会的类别呢?这一方向目前仍是一个新兴领域,主要是为了满足监管需求、加强对用户隐私的保护。随着数据保护法律的不断加强,例如欧盟的《通用数据保护条例》(GDPR),用户有权要求删除其个人数据,这也包括要求从用于训练数据驱动模型的数据集中移除其相关数据。
针对此需求,传统方案通常是针对每次删除请求重新训练模型,或者进行多轮微调:重新训练耗时长、计算成本高,而多轮微调又往往收效甚微。这是因为在深度学习中,模型可能已经从大量数据中学到了复杂的表征,其中混合了各个类别的特征;当需要从模型中移除特定类别的信息时,仅仅删除原始数据并不够,模型中往往仍保留着该类别的痕迹。
所以针对此需求,给大家介绍一种新的思路:一种基于SVD的无梯度类别遗忘算法,它主要解决了传统算法在计算成本和样本需求方面的局限性。
原文地址:Deep Unlearning: Fast and Efficient Gradient-free Class Forgetting | OpenReview
该论文的方法不算特别复杂,主要步骤如下:
• 首先对网络各层的激活做SVD,分别估计保留空间和遗忘空间,即需要保留的类别和需要遗忘的类别的特征激活所张成的子空间;
• 通过计算并去除这两个空间之间的共享信息,隔离出遗忘类别的判别性特征空间;
• 然后更新模型权重,最大程度地抑制遗忘类别样本的激活,从而得到遗忘了该类别的模型;
• 通过多种分析(包括成员推断攻击、基于显著性的特征分析和混淆矩阵分析)表明,遗忘后模型的行为与在不含遗忘类别样本的数据上重新训练得到的模型一致。
具体算法步骤如下:
论文也提供了示例代码,仓库地址:
https://github.com/sangamesh-kodge/class_forgetting
在demo.py中实现了整个流程:
1、数据生成和准备:
def get_dataset(train_samples_per_class, test_samples_per_class, std=0.5):
    """Generate a toy 4-class, 2-D Gaussian classification dataset.

    Class ``c`` is sampled from a Gaussian centred on ``means[c]`` (one of the
    four points ``(±1, ±1)``) with per-axis standard deviation ``std``.
    Train and test splits are drawn with exactly the same distribution.

    Args:
        train_samples_per_class: number of training samples per class.
        test_samples_per_class: number of test samples per class.
        std: per-axis standard deviation shared by every class.

    Returns:
        ``(train_data, test_data)``; each is a ``(data, target)`` pair where
        ``data`` has shape ``(4 * n, 2)`` and ``target`` is a ``long`` tensor
        of shape ``(4 * n,)`` holding the class index (``n`` = samples per class).
    """
    print("-"*40)
    print("Generating Dataset")
    print("-"*40)
    means = [(1.0, 1.0), (-1.0, 1.0), (-1.0, -1.0), (1.0, -1.0)]
    # One 2-element std tensor for both splits (the test split previously
    # passed a 0-d ``torch.tensor(std)`` and relied on shape-mismatch handling).
    std_dev = torch.tensor((std, std))

    def _sample_split(samples_per_class, split_name):
        # Draw ``samples_per_class`` points for every class, grouped by class.
        data_list = []
        target_list = []
        for label, mean in enumerate(means):
            mean_t = torch.tensor(mean)
            for _ in tqdm(range(samples_per_class), desc=f"{split_name} class {label}"):
                data_list.append(torch.normal(mean=mean_t, std=std_dev))
                target_list.append(torch.tensor(label).long())
        if not data_list:
            # torch.stack rejects an empty list; return correctly-shaped empties.
            return torch.empty((0, 2)), torch.empty((0,), dtype=torch.long)
        return torch.stack(data_list, 0), torch.stack(target_list, 0)

    train_data = _sample_split(train_samples_per_class, "Train")
    test_data = _sample_split(test_samples_per_class, "Test")
    return train_data, test_data
2、网络模型定义:
class Linear5(nn.Module):
    """Five-layer MLP: four Linear -> BatchNorm1d -> ReLU blocks, then a Linear head.

    Besides ``forward`` it exposes two helpers used by the SVD unlearning demo:
    ``get_activations`` snapshots the input ("pre") and raw output ("post") of
    every Linear layer, and ``project_weights`` rewrites the Linear weights
    with caller-supplied projection matrices.
    """

    def __init__(self, in_feature=2, hidden_features=5, num_classes=4):
        super().__init__()
        # Modules are created in the order fc1, bn1, fc2, bn2, ... so that
        # parameter initialisation and state_dict key order stay the same.
        self.fc1 = nn.Linear(in_features=in_feature, out_features=hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.bn2 = nn.BatchNorm1d(hidden_features)
        self.fc3 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.bn3 = nn.BatchNorm1d(hidden_features)
        self.fc4 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.bn4 = nn.BatchNorm1d(hidden_features)
        self.fc5 = nn.Linear(in_features=hidden_features, out_features=num_classes)

    def _hidden_layers(self):
        """Return the four ``(Linear, BatchNorm1d)`` pairs of the hidden stack, in order."""
        return (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
            (self.fc4, self.bn4),
        )

    def forward(self, x):
        """Return class log-probabilities (``log_softmax`` over dim 1)."""
        hidden = x
        for linear, norm in self._hidden_layers():
            hidden = F.relu(norm(linear(hidden)))
        return F.log_softmax(self.fc5(hidden), 1)

    def get_activations(self, x):
        """Capture every Linear layer's input and output during a forward pass.

        Returns ``{"pre": OrderedDict, "post": OrderedDict}`` keyed ``"fc1"``
        ... ``"fc5"``. "pre" holds the layer input, "post" the Linear output
        taken before BatchNorm/ReLU. Each value is an independent NumPy copy.
        """
        def snapshot(tensor):
            return deepcopy(tensor.clone().detach().cpu().numpy())

        captured = {"pre": OrderedDict(), "post": OrderedDict()}
        hidden = x
        for position, (linear, norm) in enumerate(self._hidden_layers(), start=1):
            layer_name = f"fc{position}"
            captured["pre"][layer_name] = snapshot(hidden)
            hidden = linear(hidden)
            captured["post"][layer_name] = snapshot(hidden)
            hidden = F.relu(norm(hidden))
        captured["pre"]["fc5"] = snapshot(hidden)
        hidden = self.fc5(hidden)
        captured["post"]["fc5"] = snapshot(hidden)
        return captured

    def project_weights(self, projection_mat_dict):
        """Rewrite the Linear weights in place using projection matrices.

        For hidden layers ``fc1``..``fc4``:
        ``W <- P_post^T @ W @ P_pre^T`` and ``b <- b @ P_post``.
        The output layer ``fc5`` only gets its input side projected
        (``W <- W @ P_pre^T``); its bias is left untouched.

        ``projection_mat_dict`` must provide ``["pre"][name]`` and
        ``["post"][name]`` tensors with matching sizes.
        """
        pre_mats = projection_mat_dict["pre"]
        post_mats = projection_mat_dict["post"]
        for position, (linear, _norm) in enumerate(self._hidden_layers(), start=1):
            layer_name = f"fc{position}"
            out_proj = post_mats[layer_name]
            linear.weight.data = torch.mm(
                out_proj.transpose(0, 1),
                torch.mm(linear.weight.data, pre_mats[layer_name].transpose(0, 1)),
            )
            linear.bias.data = torch.mm(linear.bias.data.unsqueeze(0), out_proj).squeeze(0)
        self.fc5.weight.data = torch.mm(self.fc5.weight.data, pre_mats["fc5"].transpose(0, 1))
3、模型训练和评估:
def train(model, train_loader, device, optimizer, schedular):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        model: network whose output is treated as log-probabilities
            (it is fed straight into ``F.nll_loss``).
        train_loader: iterable of ``(data, target)`` batches with a
            ``.dataset`` attribute.
        device: device every batch is moved to.
        optimizer: stepped once per batch.
        schedular: stepped once per batch with the batch loss as argument.
            NOTE(review): this matches a metric-driven scheduler such as
            ``ReduceLROnPlateau`` — confirm against the caller.

    Returns:
        Negative log-likelihood averaged over every sample of
        ``train_loader.dataset`` (computed with the pre-update weights of
        each batch).
    """
    model.train()
    num_samples = len(train_loader.dataset)
    total_loss = 0
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        log_prob = model(data)
        loss = F.nll_loss(log_prob, target)
        # nll_loss averages over the batch, so re-weight by the batch size;
        # summed over batches this gives the dataset-wide mean even when
        # the last batch is smaller.
        total_loss += loss.detach().item() * data.size(0) / num_samples
        loss.backward()
        optimizer.step()
        schedular.step(loss)
    return total_loss
def test(model, test_loader, device):
model.eval()
total_loss = 0
correct = 0
for data, target in test_loader:
data, target = data.to(device), target.to(device)
log_prob = model(data)
pred = log_prob.argmax(dim=1, keepdim=True)
loss = F.nll_loss(log_prob, target)
total_loss+= loss.detach().item()/len(test_loader.dataset)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100*correct/len(test_loader.dataset)
return total_loss, acc
4、SVD遗忘空间的计算
def _fill_svd_bases(act_dict, mat_dict, var_dict, svd_device):
    """Store the SVD basis of every activation matrix in ``act_dict``.

    For each ``act_dict[loc][key]`` array, the transposed matrix is
    decomposed. Its left singular vectors ``U`` go into ``mat_dict[loc][key]``
    and each singular direction's share of the total energy
    ``S**2 / sum(S**2)`` goes into ``var_dict[loc][key]``. Both results are
    NumPy arrays. ``mat_dict``/``var_dict`` must already contain the ``loc``
    sub-dicts.
    """
    for loc, layer_acts in act_dict.items():
        for key, act in layer_acts.items():
            # Transposed so rows index features (assuming act is samples x
            # features, as get_activations returns); U then spans feature space.
            activation = torch.Tensor(act).to(svd_device).transpose(0,1)
            U, S, _ = torch.linalg.svd(activation, full_matrices=False)
            S = S.cpu().numpy()
            mat_dict[loc][key] = U.cpu().numpy()
            var_dict[loc][key] = S**2 / (S**2).sum()

# Compute Ur (run the SVD on the same device as the model instead of a
# hard-coded "cuda", so the demo also works on CPU-only machines)
_fill_svd_bases(retain_act, retain_mat_dict, retain_normalize_var_mat_dict, device)
forget_data = data1[0][:train_samples_per_class]
forget_act = model.get_activations(forget_data.to(device))
forget_mat_dict= {"pre":OrderedDict(), "post":OrderedDict()}
forget_normalize_var_mat_dict = {"pre":OrderedDict(), "post":OrderedDict()}
# Compute Uf
_fill_svd_bases(forget_act, forget_mat_dict, forget_normalize_var_mat_dict, device)
在main.py中还实现了多种遗忘算法,包括SSD、OUR、TARUN、SCRUB、GOEL和SALUN,下面摘录其中SALUN、GOEL和SSD的实现:
SALUN:
def salun_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, optimizer, epochs, **kwargs):
    """SALUN baseline: masked fine-tuning on relabelled forget data plus retain data.

    Forget samples get uniformly random labels. A random subset of
    ``args.num_forget_samples`` of them is mixed with ``args.num_retain_samples``
    retain samples, and the model is fine-tuned for ``epochs`` epochs. When
    ``get_salun_mask`` returns a truthy mask, every gradient is multiplied
    element-wise by ``mask[name]`` before the optimizer step. The mask is
    presumably a dict of per-parameter tensors — confirm against
    ``get_salun_mask``.

    ``train_loader``, ``test_loader`` and ``kwargs`` are accepted for a common
    unlearning-method signature and are not used here.

    Returns:
        The fine-tuned ``model`` (also updated in place).
    """
    mask = get_salun_mask(args, model, device, forget_loader)
    model.train()
    # Random relabeling on a deep copy, so the caller's forget dataset is not
    # modified (assumes update_labels relabels in place — confirm).
    forget_dataset = copy.deepcopy(forget_loader.dataset)
    forget_dataset.update_labels(np.random.randint(0, args.num_classes, len(forget_dataset)))
    forget_dataset, _ = torch.utils.data.random_split(forget_dataset, [args.num_forget_samples, len(forget_dataset) - args.num_forget_samples])
    # Subsample retain set.
    retain_dataset = retain_loader.dataset
    retain_dataset, _ = torch.utils.data.random_split(retain_dataset, [args.num_retain_samples, len(retain_dataset) - args.num_retain_samples])
    train_dataset = torch.utils.data.ConcatDataset([forget_dataset, retain_dataset])
    salun_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    for _ in range(epochs):
        for data, target in salun_train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            optimizer.zero_grad()
            loss.backward()
            if mask:
                # Only parameters selected by the saliency mask are updated.
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        param.grad *= mask[name]
            optimizer.step()
            if args.dry_run:
                break
        if args.dry_run:
            break
    return model
GOEL:
def goel_last_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, optimizer, epochs, **kwargs):
    """GOEL baseline: reset the final part of the network, then fine-tune on retain data.

    Calls ``resetFinalResnet(model, 1, reinit=args.goel_exact)``, then trains
    for ``epochs`` epochs on a random subset of ``args.num_retain_samples``
    retain samples. What exactly is reset depends on ``resetFinalResnet``,
    which is defined elsewhere.

    ``forget_loader``, ``train_loader``, ``test_loader`` and ``kwargs`` are
    accepted for a common unlearning-method signature and are not used here.

    Returns:
        The model returned by ``resetFinalResnet`` after fine-tuning.
    """
    # Subsample retain set.
    retain_dataset = retain_loader.dataset
    retain_dataset, _ = torch.utils.data.random_split(retain_dataset, [args.num_retain_samples, len(retain_dataset) - args.num_retain_samples])
    goel_train_loader = torch.utils.data.DataLoader(retain_dataset, batch_size=args.batch_size, shuffle=True)
    model.train()
    # NOTE(review): if resetFinalResnet replaces modules (rather than
    # re-initialising them in place), `optimizer` still references the old
    # parameters and the new layers would never be updated — confirm.
    model = resetFinalResnet(model, 1, reinit=args.goel_exact)
    for _ in range(epochs):
        for data, target in goel_train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.dry_run:
                break
        if args.dry_run:
            break
    return model
SSD:
def ssd_unlearn(args, model, device, retain_loader, forget_loader, train_loader, test_loader, optimizer, **kwargs):
    """SSD (selective synaptic dampening) baseline.

    Parameter importances are computed with ``calc_importance`` on the
    subsampled forget set and on the subsampled forget+retain set. Parameter
    entries whose forget importance exceeds ``alpha`` times the full-data
    importance are scaled by ``min(1, (lambda * oimp / fimp) ** exponent)``.

    ``train_loader``, ``test_loader`` and ``kwargs`` are accepted for a common
    unlearning-method signature and are not used here.

    Returns:
        ``model`` with the selected parameters dampened in place.
    """
    # Subsample dataset. The forget subsample is now actually used: previously
    # the full forget_loader was used for importances and for D, which
    # silently ignored args.num_forget_samples.
    forget_dataset = forget_loader.dataset
    forget_dataset, _ = torch.utils.data.random_split(forget_dataset, [args.num_forget_samples, len(forget_dataset) - args.num_forget_samples])
    retain_dataset = retain_loader.dataset
    retain_dataset, _ = torch.utils.data.random_split(retain_dataset, [args.num_retain_samples, len(retain_dataset) - args.num_retain_samples])
    ssd_forget_loader = torch.utils.data.DataLoader(forget_dataset, batch_size=args.batch_size, shuffle=False)
    train_dataset = torch.utils.data.ConcatDataset([forget_dataset, retain_dataset])
    ssd_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    # Hyperparameters
    lower_bound = 1  # upper cap on the dampening factor, so parameters never grow
    exponent = 1  # power applied to the dampening ratio
    dampening_constant = args.ssd_lambda  # Lambda from paper
    selection_weighting = args.ssd_alpha  # Alpha from paper
    model.eval()  # to ensure batch statistics do not change
    # Calculation of the forget set importances
    forget_importance = calc_importance(model, device, optimizer, ssd_forget_loader)
    # Calculate the importances of D (see paper); this can also be done at any point before forgetting.
    original_importance = calc_importance(model, device, optimizer, ssd_train_loader)
    # Dampen selected parameters
    with torch.no_grad():
        # NOTE(review): zip pairs entries positionally, so both importance dicts
        # must be ordered like model.named_parameters() — confirm in calc_importance.
        for (_, p), (_, oimp), (_, fimp) in zip(
            model.named_parameters(),
            original_importance.items(),
            forget_importance.items(),
        ):
            # Synapse Selection with parameter alpha
            locations = torch.where(fimp > oimp.mul(selection_weighting))
            # Synapse Dampening with parameter lambda, evaluated only where
            # selected (fimp exceeds alpha*oimp there, so it is non-zero when
            # importances are non-negative).
            update = (oimp[locations].mul(dampening_constant).div(fimp[locations])).pow(exponent)
            # Bound by 1 to prevent parameter values to increase.
            update = update.clamp(max=lower_bound)
            p[locations] = p[locations].mul(update)
    return model
最终,研究人员在ImageNet数据集上使用Vision Transformer(ViT)验证了该算法:在保留类别准确率仅下降约1.5%的情况下,遗忘类别的准确率被降至1%以下,并且在面对成员推断攻击时表现出较强的韧性。与基线方法相比,新算法在ImageNet上的准确率平均提高了1.38%,同时所需的遗忘样本数量减少了10倍。此外,在CIFAR-100数据集的ResNet18上,面对更强的成员推断攻击,该算法的结果也比最佳基线高出1.8%。感兴趣的同学可以跑一下demo代码。