1. 目前实现GBDT算法比较好的库是xgboost,scikit-learn也可以。我使用的是sklearn,需要安装numpy和sklearn,安装时会遇到各种包的依赖问题。建议直接安装anaconda3,这样所需的包就全部装好了。下载地址:https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/
2.python3的代码如下(和python2差不多)
3. 训练速度特别快,100万(100W)条训练数据不到5分钟。
# -*- coding:utf-8 -*-
import numpy as np
import codecs
import pickle
from sklearn import ensemble
from sklearn import datasets
# 加载训练数据
def load_train_data(path="all_train.txt.108W.final.random"):
    """Read the tab-separated training file into feature and label arrays.

    Each usable line must have at least five tab-separated columns: column
    index 3 is taken as the label and every column from index 4 onward is
    taken as a feature. Values are kept as strings; converting them to
    numbers is left to the caller.

    Args:
        path: Location of the UTF-8 encoded training file. Defaults to the
            file name this script was originally written against.

    Returns:
        A ``(X, y)`` tuple of NumPy string arrays: ``X`` holds one row of
        feature columns per usable line, ``y`` holds the matching labels.
    """
    X = []
    y = []
    with codecs.open(path, 'r', 'utf8') as reader:
        for line in reader:
            tokens = line.strip().split("\t")
            # Blank or truncated lines have no label and/or no features;
            # skip them instead of failing with IndexError (the prediction
            # loader skips malformed lines in the same way).
            if len(tokens) < 5:
                continue
            y.append(tokens[3])
            X.append(tokens[4:])
    return np.array(X), np.array(y)
# 加载预测数据
def load_pred_data(path="data/pred.dat", expected_columns=27):
    """Read the tab-separated prediction file into id and feature arrays.

    Only lines with exactly ``expected_columns`` tab-separated columns are
    used; every other line is skipped. For a usable line, column index 3 is
    taken as the record id and every column from index 6 onward is taken as
    a feature. Values are kept as strings.

    Args:
        path: Location of the UTF-8 encoded prediction file.
        expected_columns: Exact number of columns a line must have to be
            used. Must be large enough for column index 3 to exist.

    Returns:
        An ``(ids, inputs)`` tuple of NumPy string arrays: ``ids`` holds
        the record ids, ``inputs`` the matching rows of feature columns.
    """
    ids = []
    inputs = []
    with codecs.open(path, 'r', 'utf8') as reader:
        for line in reader:
            tokens = line.strip().split("\t")
            # Rows with an unexpected column count are dropped so that all
            # feature rows have the same width.
            if len(tokens) != expected_columns:
                continue
            ids.append(tokens[3])
            inputs.append(tokens[6:])
    return np.array(ids), np.array(inputs)
# 保存模型
def save_model(clf, path="gbdt_classifier.model"):
    """Pickle a trained model to disk.

    Args:
        clf: The object to serialize (the trained classifier).
        path: Destination file; it is created or overwritten. Defaults to
            ``gbdt_classifier.model`` in the current working directory.
    """
    # The context manager closes the file even if pickling fails part-way.
    with open(path, "wb") as model_file:
        pickle.dump(clf, model_file)
# 训练GBDT模型
def train_gbdt(train_size=90000):
    """Train a gradient-boosting classifier and save it with ``save_model``.

    The data returned by ``load_train_data`` is split positionally: the
    first ``train_size`` rows are used for fitting and the remaining rows
    are used to report accuracy. No shuffling is done here.

    Args:
        train_size: Number of leading rows used for fitting.
            NOTE(review): with the default, only the first 90000 rows are
            trained on; if the input really has about 1.08M rows (as the
            default training file name suggests) most of the data ends up
            in the evaluation split -- confirm this is intended.
    """
    X, y = load_train_data()
    X = X.astype(np.float32)
    # Replace each label by its index in the sorted unique label values.
    # The fitted model therefore predicts these indices, not the raw labels.
    _, y = np.unique(y, return_inverse=True)
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]
    print(X[0:1])
    print(y[0:1])
    print('+++++++++')
    params = {
        'n_estimators': 1000,
        'max_leaf_nodes': 4,
        'max_depth': None,
        'random_state': 2,
        'min_samples_split': 5,
        'learning_rate': 0.1,
        'max_features': 2,
    }
    clf = ensemble.GradientBoostingClassifier(**params)
    clf.fit(X_train, y_train)
    if len(X_test):
        print("Accuracy : %.4g" % clf.score(X_test, y_test))
    else:
        # Nothing was held out, so there is nothing to score; still save.
        print("Accuracy : n/a (no held-out rows)")
    save_model(clf)
# 基于GBDT模型预测
def predict_gbdt(model_path='model/gbdt_classifier.model',
                 output_path='result/pred_classifier.dat'):
    """Predict a class for every prediction row and write ``id<TAB>class``.

    Rows come from ``load_pred_data``. One line per row is written to
    ``output_path``, in input order.

    Args:
        model_path: Pickled classifier to load.
            NOTE(review): ``save_model`` writes to ``gbdt_classifier.model``
            in the working directory by default, while this default reads
            from the ``model/`` directory -- move the file there or pass
            ``model_path`` explicitly.
        output_path: Result file; it is created or overwritten.
    """
    ids, inputs = load_pred_data()
    # The loader returns strings; training casts features to float32, so
    # apply the same conversion before handing them to the model.
    inputs = inputs.astype(np.float32)
    # Pickle data is binary: the file must be opened with 'rb' on Python 3.
    # SECURITY: unpickling can execute arbitrary code -- only load model
    # files that come from a trusted source.
    with open(model_path, 'rb') as model_file:
        clf = pickle.load(model_file)
    # Predict before opening the output so a failure here does not leave a
    # truncated result file behind.
    preds = clf.predict(inputs)
    # Write as UTF-8 to match the encoding the ids were read with.
    with open(output_path, 'w', encoding='utf8') as out:
        for row_id, pred in zip(ids, preds):
            out.write("%s\t%s\n" % (row_id, pred))
# Script entry point: executing this file trains the model and saves it.
train_gbdt()
# To run inference with a previously saved model instead, call predict_gbdt() here.
GBDT的使用例子
于 2020-06-10 10:18:45 首次发布