EWC代码
持续学习相关代码(EWC 等算法的完整实现)可从原文提供的链接下载。
以下为 EWC 算法的 PyTorch 实现。
ewc.py代码
import sys, time
import numpy as np
import torch
from copy import deepcopy
import utils
class Appr(object):
    """Elastic Weight Consolidation (EWC) trainer.

    Implements the approach of http://arxiv.org/abs/1612.00796: after
    finishing each task, a diagonal Fisher information matrix is estimated;
    while training later tasks, a quadratic penalty (weighted by ``lamb``)
    anchors the parameters that were important for earlier tasks.
    """

    def __init__(self, model, nepochs=100, sbatch=64, lr=0.05, lr_min=1e-4, lr_factor=3, lr_patience=5, clipgrad=100,
                 lamb=5000, args=None):
        """Set up the trainer.

        Args:
            model: multi-head network; ``model.forward(x)`` must return a
                list with one output tensor per task.
            nepochs: maximum number of epochs per task.
            sbatch: mini-batch size.
            lr: initial SGD learning rate.
            lr_min: stop training once the decayed lr would fall below this.
            lr_factor: divisor applied to the lr when validation stalls.
            lr_patience: epochs without validation improvement before decay.
            clipgrad: maximum gradient norm for clipping.
            lamb: EWC regularization strength.
            args: optional namespace; if present and ``args.parameter`` is a
                non-empty string, its first comma-separated field overrides
                ``lamb``.
        """
        self.model = model
        self.model_old = None  # frozen snapshot of the model after the previous task
        self.fisher = None     # running (merged) Fisher diagonal, dict: name -> tensor
        self.nepochs = nepochs
        self.sbatch = sbatch
        self.lr = lr
        self.lr_min = lr_min
        self.lr_factor = lr_factor
        self.lr_patience = lr_patience
        self.clipgrad = clipgrad
        self.ce = torch.nn.CrossEntropyLoss()
        self.optimizer = self._get_optimizer()
        self.lamb = lamb  # Grid search = [500,1000,2000,5000,10000,20000,50000]; best was 5000
        # Bug fix: the original indexed args.parameter unconditionally, which
        # raised AttributeError whenever args was left at its default (None).
        if args is not None and len(args.parameter) >= 1:
            params = args.parameter.split(',')
            print('Setting parameters to', params)
            self.lamb = float(params[0])
        return

    def _get_optimizer(self, lr=None):
        """Build a fresh SGD optimizer over all model parameters."""
        if lr is None:
            lr = self.lr
        return torch.optim.SGD(self.model.parameters(), lr=lr)

    def train(self, t, xtrain, ytrain, xvalid, yvalid):
        """Train on task ``t`` with lr decay / early stopping on validation
        loss, then snapshot the model and update the merged Fisher diagonal."""
        best_loss = np.inf
        best_model = utils.get_model(self.model)
        lr = self.lr
        patience = self.lr_patience
        self.optimizer = self._get_optimizer(lr)
        # Loop epochs
        for e in range(self.nepochs):
            # Train
            clock0 = time.time()
            self.train_epoch(t, xtrain, ytrain)
            clock1 = time.time()
            train_loss, train_acc = self.eval(t, xtrain, ytrain)
            clock2 = time.time()
            print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format(
                e + 1, 1000 * self.sbatch * (clock1 - clock0) / xtrain.size(0),
                1000 * self.sbatch * (clock2 - clock1) / xtrain.size(0), train_loss, 100 * train_acc), end='')
            # Valid
            valid_loss, valid_acc = self.eval(t, xvalid, yvalid)
            print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, 100 * valid_acc), end='')
            # Adapt lr: keep the best model so far; after `patience` epochs
            # with no improvement divide the lr, stopping below lr_min.
            if valid_loss < best_loss:
                best_loss = valid_loss
                best_model = utils.get_model(self.model)
                patience = self.lr_patience
                print(' *', end='')
            else:
                patience -= 1
                if patience <= 0:
                    lr /= self.lr_factor
                    print(' lr={:.1e}'.format(lr), end='')
                    if lr < self.lr_min:
                        print()
                        break
                    patience = self.lr_patience
                    self.optimizer = self._get_optimizer(lr)
            print()
        # Restore the best validation checkpoint
        utils.set_model_(self.model, best_model)
        # Keep a frozen copy of the network for the EWC penalty
        self.model_old = deepcopy(self.model)
        self.model_old.eval()
        utils.freeze_model(self.model_old)  # Freeze the weights
        # Fisher ops
        if t > 0:
            fisher_old = {}
            for n, _ in self.model.named_parameters():
                fisher_old[n] = self.fisher[n].clone()
        self.fisher = utils.fisher_matrix_diag(t, xtrain, ytrain, self.model, self.criterion)
        if t > 0:
            # We do not want to keep t models (or fisher diagonals) in memory,
            # therefore we merge the new diagonal with the running average.
            for n, _ in self.model.named_parameters():
                self.fisher[n] = (self.fisher[n] + fisher_old[n] * t) / (
                        t + 1)  # Checked: it is better than the other option
                # self.fisher[n]=0.5*(self.fisher[n]+fisher_old[n])
        return

    def train_epoch(self, t, x, y):
        """Run one epoch of SGD over a random permutation of task ``t`` data."""
        self.model.train()
        r = np.arange(x.size(0))
        np.random.shuffle(r)
        r = torch.LongTensor(r).cuda()
        # Loop batches
        for i in range(0, len(r), self.sbatch):
            if i + self.sbatch <= len(r):
                b = r[i:i + self.sbatch]
            else:
                b = r[i:]
            # NOTE(review): Variable/volatile are pre-0.4 PyTorch idioms;
            # kept for compatibility with the rest of this code base.
            images = torch.autograd.Variable(x[b], volatile=False)
            targets = torch.autograd.Variable(y[b], volatile=False)
            # Forward current model
            outputs = self.model.forward(images)
            output = outputs[t]  # task-specific head
            loss = self.criterion(t, output, targets)
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            # NOTE(review): clip_grad_norm was renamed clip_grad_norm_ in
            # PyTorch >= 0.4; kept as-is to match the file's API level.
            torch.nn.utils.clip_grad_norm(self.model.parameters(), self.clipgrad)
            self.optimizer.step()
        return

    def eval(self, t, x, y):
        """Evaluate head ``t`` on (x, y); returns (mean loss, accuracy)."""
        total_loss = 0
        total_acc = 0
        total_num = 0
        self.model.eval()
        r = np.arange(x.size(0))
        r = torch.LongTensor(r).cuda()
        # Loop batches
        for i in range(0, len(r), self.sbatch):
            if i + self.sbatch <= len(r):
                b = r[i:i + self.sbatch]
            else:
                b = r[i:]
            images = torch.autograd.Variable(x[b], volatile=True)
            targets = torch.autograd.Variable(y[b], volatile=True)
            # Forward
            outputs = self.model.forward(images)
            output = outputs[t]
            loss = self.criterion(t, output, targets)
            _, pred = output.max(1)
            hits = (pred == targets).float()
            # Log: weight the batch statistics by the actual batch size
            total_loss += loss.data.cpu().numpy() * len(b)
            total_acc += hits.sum().data.cpu().numpy()
            total_num += len(b)
        return total_loss / total_num, total_acc / total_num

    def criterion(self, t, output, targets):
        """Cross-entropy plus the EWC quadratic penalty for previous tasks."""
        # Regularization for all previous tasks:
        # sum_i F_i * (theta_old_i - theta_i)^2 / 2
        loss_reg = 0
        if t > 0:
            for (name, param), (_, param_old) in zip(self.model.named_parameters(),
                                                     self.model_old.named_parameters()):
                loss_reg += torch.sum(self.fisher[name] * (param_old - param).pow(2)) / 2
        return self.ce(output, targets) + self.lamb * loss_reg
其中utils.py代码如下:
import os, sys
import numpy as np
from copy import deepcopy
import torch
from tqdm import tqdm
########################################################################################################################
def print_model_report(model):
    """Print a summary of *model*: its repr, each parameter's shape, and the
    total parameter count, which is also returned."""
    separator = '-' * 100
    print(separator)
    print(model)
    sizes = [p.size() for p in model.parameters()]
    print('Dimensions =', end=' ')
    for shape in sizes:
        print(shape, end=' ')
    print()
    count = sum(np.prod(shape) for shape in sizes)
    print('Num parameters = %s' % (human_format(count)))
    print(separator)
    return count
def human_format(num):
    """Render *num* with one decimal and a metric suffix, e.g. 1500 -> '1.5K'."""
    suffixes = ['', 'K', 'M', 'G', 'T', 'P']
    magnitude = 0
    while abs(num) >= 1000:
        num /= 1000.0
        magnitude += 1
    return '%.1f%s' % (num, suffixes[magnitude])
def print_optimizer_config(optim):
    """Print the hyper-parameters of *optim*'s first param group (or None)."""
    if optim is None:
        print(optim)
        return
    print(optim, '=', end=' ')
    group = optim.param_groups[0]
    # Skip the 'params' entry; only show scalar hyper-parameters.
    for key in group:
        if not key.startswith('param'):
            print(key + ':', group[key], end=', ')
    print()
    return
########################################################################################################################
def get_model(model):
    """Return an independent deep-copied snapshot of *model*'s state dict."""
    snapshot = model.state_dict()
    return deepcopy(snapshot)
def set_model_(model, state_dict):
    """Restore *model*'s parameters in place from a snapshot (see get_model)."""
    own_copy = deepcopy(state_dict)
    model.load_state_dict(own_copy)
    return
def freeze_model(model):
    """Disable gradient tracking on every parameter of *model*."""
    for p in model.parameters():
        p.requires_grad_(False)
    return
########################################################################################################################
def compute_conv_output_size(Lin, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution (the standard PyTorch formula)."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return int(np.floor((Lin + 2 * padding - effective_kernel) / float(stride) + 1))
########################################################################################################################
def compute_mean_std_dataset(dataset):
    """Per-channel mean and std over a dataset of already-ToTensor'd images.

    Two passes with batch_size=1: the first accumulates per-image channel
    means; the second accumulates squared deviations from the dataset mean.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    mean = 0
    for image, _ in loader:
        mean += image.mean(3).mean(2)
    mean /= len(dataset)
    # NOTE: reuses the last `image` from the loop above purely for its shape
    # (all batches have identical shape at batch_size=1).
    mean_expanded = mean.view(mean.size(0), mean.size(1), 1, 1).expand_as(image)
    std = 0
    for image, _ in loader:
        std += (image - mean_expanded).pow(2).sum(3).sum(2)
    # Bessel-style denominator over all N*H*W pixel samples per channel.
    std = (std / (len(dataset) * image.size(2) * image.size(3) - 1)).sqrt()
    return mean, std
########################################################################################################################
def fisher_matrix_diag(t, x, y, model, criterion, sbatch=20):
    """Estimate the diagonal of the empirical Fisher information matrix.

    Args:
        t: task index; head ``model.forward(x)[t]`` is evaluated.
        x, y: full inputs/targets for the task, indexed in mini-batches.
        model: multi-head network whose forward returns a list of outputs.
        criterion: callable ``criterion(t, output, target)`` returning a loss.
        sbatch: mini-batch size for the gradient passes.

    Returns:
        dict mapping parameter name -> tensor of squared gradients averaged
        over the dataset (detached, requires_grad=False).
    """
    # Init: one zero accumulator per named parameter.
    fisher = {}
    for n, p in model.named_parameters():
        fisher[n] = 0 * p.data
    # Use the model's own device instead of hard-coding .cuda(), so the
    # estimate also runs on CPU; unchanged behavior for CUDA models.
    device = next(model.parameters()).device
    # tqdm is purely cosmetic; fall back to a plain range when unavailable.
    try:
        from tqdm import tqdm
        batches = tqdm(range(0, x.size(0), sbatch), desc='Fisher diagonal', ncols=100, ascii=True)
    except ImportError:
        batches = range(0, x.size(0), sbatch)
    # Compute
    model.train()
    for i in batches:
        b = torch.LongTensor(np.arange(i, np.min([i + sbatch, x.size(0)]))).to(device)
        images = torch.autograd.Variable(x[b], volatile=False)
        target = torch.autograd.Variable(y[b], volatile=False)
        # Forward and backward on this mini-batch
        model.zero_grad()
        outputs = model.forward(images)
        loss = criterion(t, outputs[t], target)
        loss.backward()
        # Accumulate squared gradients weighted by the *actual* batch size.
        # Bug fix: the original multiplied by `sbatch` unconditionally,
        # over-weighting the final (possibly smaller) batch.
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += len(b) * p.grad.data.pow(2)
    # Mean over the dataset, detached from the autograd graph.
    for n, _ in model.named_parameters():
        fisher[n] = fisher[n] / x.size(0)
        fisher[n] = torch.autograd.Variable(fisher[n], requires_grad=False)
    return fisher
########################################################################################################################
def cross_entropy(outputs, targets, exp=1, size_average=True, eps=1e-5):
    """Soft-target cross-entropy between two sets of logits.

    Both `outputs` and `targets` are logits of shape (batch, classes); each
    is turned into a probability distribution with softmax over dim 1.
    `exp` applies a temperature-like power sharpening to both distributions,
    and `eps` is mixed into the prediction so log() never sees a zero.

    Returns the per-sample loss, averaged over the batch if `size_average`.
    """
    # Bug fix: softmax without an explicit dim is deprecated and its implicit
    # choice is ambiguous; pin dim=1 (the class dimension).
    out = torch.nn.functional.softmax(outputs, dim=1)
    tar = torch.nn.functional.softmax(targets, dim=1)
    if exp != 1:
        # Sharpen both distributions, then renormalize.
        out = out.pow(exp)
        out = out / out.sum(1).view(-1, 1).expand_as(out)
        tar = tar.pow(exp)
        tar = tar / tar.sum(1).view(-1, 1).expand_as(tar)
    # Smooth the prediction away from exact zeros, then renormalize.
    out = out + eps / out.size(1)
    out = out / out.sum(1).view(-1, 1).expand_as(out)
    ce = -(tar * out.log()).sum(1)
    if size_average:
        ce = ce.mean()
    return ce
########################################################################################################################
def set_req_grad(layer, req_grad):
    """Set requires_grad on a layer's weight and bias tensors, if present.

    Bug fix: modules created with ``bias=False`` still *have* a ``bias``
    attribute (set to None), so the original hasattr-only check crashed with
    an AttributeError on None; we also require the attribute to be non-None.
    """
    if getattr(layer, 'weight', None) is not None:
        layer.weight.requires_grad = req_grad
    if getattr(layer, 'bias', None) is not None:
        layer.bias.requires_grad = req_grad
    return
########################################################################################################################
def is_number(s):
    """Return True if *s* parses as a float or as a single unicode numeral."""
    try:
        float(s)
    except ValueError:
        pass
    else:
        return True
    # Fall back to unicode numerals like '½' or '五' (single characters).
    try:
        import unicodedata
        unicodedata.numeric(s)
    except (TypeError, ValueError):
        return False
    return True
########################################################################################################################