这是我的第296篇原创文章。
一、引言
机器学习模型的训练过程有动态和静态两类:
-
静态训练模型采用离线训练方式。一般只训练模型一次,然后长时间使用该模型。
-
动态训练模型采用增量训练(分为在线和离线)方式。数据会不断进入系统,通过不断地更新系统将这些数据整合到模型中。
二、应用场景
应用场景1:
做过机器学习的朋友都知道,有时候训练数据是很多的,几十万几百万也是常有的事。虽然几十万几百万只看记录数不算多,但是如果有几百个特征呢,那数据集是很恐怖的,如果存成numpy.float类型,那绝对是把内存吃爆。在超大数据集上,一般有这么几种方法:1. 对数据进行降维,2. 增量训练,使用流式或类似流式处理,3. 上大机器,高内存的,或者用spark集群。
应用场景2:
在实际业务场景会碰到这样一个问题,训练数据随着随着时间的累计越来越多。这样如果每一次训练都把所有的样本,训练一次,既浪费资源又耽误时间。所以,希望可以实现基于已有的模型,直接训练新的数据。比如,我用第一个月的数据训练好了一个模型,现在又来了第二个月的数据,以往的方式是把一月和二月的数据拼起来重新训练模型,现在希望的基于一月份已经获得的模型,直接训练二月份的数据。
总结一下,增量训练的主要用途有两个,一个是想办法利用全部的数据,另一个是想办法及时利用新的数据。
三、实现方式
在线动态增量训练:
在线学习的典型代表是用SGD优化的logistics regress
,先用数据初始化参数,线上来一个数据更新一次参数,虽然时间的推移,效果越来越好。这样就避免了离线更新模型的问题。
离线动态增量训练:
-
利用一月份的数据训练得到模型,并保存。
-
调用保存的模型,对其进行fit(其实,就是连续fit模型就行), 这样得到的模型与把一月二月数据合在一起训练得到的模型结果可能不一致,当前后两个数据分布不一致的时候,会使模型发生偏移。
四、离线动态增量训练实现
4.1 准备数据
data = pd.read_csv(r'Dataset.csv')
df = pd.DataFrame(data)
print(df)
# 提取目标变量和特征变量
target = 'target'
features = df.columns.drop(target)
x = df[features].values
y = df[target].values
df:
4.2 划分训练集和测试集
x_disorder, y_disorder = shuffle(x, y, random_state=1)
x_train, x_test, y_train, y_test = train_test_split(x_disorder, y_disorder, random_state=3)
4.3 训练集进一步划分为两部分
训练集做进一步拆解,一部分用于第一次训练,另一部分当作是新来的数据,用作第二次训练:
x_old_train = x_train[:90, :]
y_old_train = y_train[:90]
x_new_train = x_train[90:, :]
y_new_train = y_train[90:]
4.4 离线动态增量训练
先在一部分数据上fit训练,然后在此训练好模型的基础上继续fit:
model = RandomForestClassifier()
model = model.fit(x_old_train, y_old_train) # 先用一部分训练数据,训练模型,并保存
with open('old_version.pickle', 'wb') as f:
pickle.dump(model, f)
# 调用刚刚保存的模型,比如此时又来新的训练数据,这时模型继续训练,再预测
pickle_in = open('old_version.pickle', 'rb')
model = pickle.load(pickle_in)
model = model.fit(x_new_train, y_new_train)
先用一部分训练数据,训练模型,并保存,调用刚刚保存的模型,比如此时又来新的训练数据,这时模型继续训练。
4.5 直接一次性训练
直接训练所有数据:
all_model = RandomForestClassifier()
all_model = all_model.fit(x_train, y_train)
4.6 结果比较
两种结果存在一定的偏差,但是这样的好处是:模型参数不需要重新训练,只需要再以前的基础上继续训练节省时间。
作者简介:
读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。