基本思想
自步学习(self-paced learning),从课程学习(curriculum learnig)衍生而来。基本思想是在学习过程中根据先易后难的顺序学习,自步学习将难易程度与样本的损失值进行关联,在模型训练过程中,判断样本损失值是否超过阈值,将样本划分为难学习以及易学习样本。
demo演示
由于实现自步学习依附的主体实验使用keras框架编写,因此练手的demo使用tensorflow2.3进行简单演示。写demo的难点在于keras训练过程以及损失函数的编写。
使用步骤
1.keras损失函数编写
关于损失函数计算,编写可以参考tensorflow官方文档,此处直接挪用。
这里给出一个通用的训练代码:
def train_step(X, y):
with tf.GradientTape() as tape:
pred = temp_model(X,training=True)
loss = loss_fn(y,pred)
grads = tape.gradient(loss,temp_model.trainable_weights)
opt.apply_gradients(zip(grads,temp_model.trainable_weights))
return loss
temp_model是自定义的模型,X为输入的数据集,y为输入的标签。loss_fn为损失函数,opt为optimizer,二分类任务,输入数据集为feature set,此处直接使用keras自带的Adam Optimizer。
损失函数为
class SPLLoss():
def __init__(self, n_samples=0):
self.threshold = 0.000000001
self.growing_factor = 1.35
self.v = np.zeros(n_samples)
def forward(self, input, target, index):
super_loss = tf.keras.losses.binary_crossentropy(input,target)
v = self.spl_loss(super_loss)
j = 0
for i in index:
self.v[i] = v[j].numpy()
j += 1
return tf.math.multiply(super_loss,v.numpy())/sum(v.numpy())
def increase_threshold(self):
self.threshold *= self.growing_factor
def spl_loss(self, super_loss):
v = super_loss < self.threshold
return v
loss_fn = SPLLoss(n_samples=len(x_train_triggered))
threshold以及growing_factor均为自步学习训练过程的超参,threshold为判断样本难易程度的阈值,growing_factor为每个epoch训练后,对threshold进行更新的参数。
2.模型训练
训练代码如下:
for epoch in range(epochs):
print('\nStart epoch',epoch)
for step,data in enumerate(train_A):
X,y = data
index = batch_idx[step]
loss = train_step(X, y, index)
if step %10 == 0:
print('self pacing learnig loss at step %d: %.5f'%(step,loss))
loss_fn.increase_threshold()
train_A为封装好的训练数据集,这里使用tf.data.Dataset.from_tensor_slices方法。
总结
本文仅仅简单介绍了自步学习的简单原理和在keras框架下的复现,对于torch框架以及tensorflow1.0框架并未尝试,感兴趣的可以自己实现,欢迎交流。
参考
1.损失函数代码:https://keras.io/guides/writing_a_training_loop_from_scratch/
2.训练过程损失函数值反向传播更新模型参数的代码:https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/GradientTape
3.自步学习经典文献:Kumar M P, Packer B, Koller D. Self-Paced Learning for Latent Variable Models. NIPS [C], Cambridge: MIT Press, 2010, 1: 2.