LightGBM is a gradient boosting framework.
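In gradient boosting, the model is a sum of decision trees, and each new tree is fit to the negative gradient of the loss with respect to the current predictions. Below is a minimal pure-Ruby sketch of that idea for squared-error regression, where the negative gradient is simply the residual, using one-split trees (stumps). `Stump`, `fit_stump`, `gradient_boost` and the toy data are hypothetical names for illustration only. This is not how LightGBM is implemented: LightGBM grows deeper trees leaf-wise using histogram-based split finding.

```ruby
# Toy gradient boosting (hypothetical helper names, illustration only).
# Squared-error loss, depth-1 trees ("stumps") on a single numeric feature.

Stump = Struct.new(:threshold, :left_value, :right_value) do
  def predict(x)
    x <= threshold ? left_value : right_value
  end
end

# Fit a stump to the residuals rs: try every split point and keep the one
# with the smallest squared error when each side predicts its mean residual.
def fit_stump(xs, rs)
  best = nil
  best_err = Float::INFINITY
  xs.uniq.sort.each do |t|
    left = rs.each_index.select { |i| xs[i] <= t }.map { |i| rs[i] }
    right = rs.each_index.select { |i| xs[i] > t }.map { |i| rs[i] }
    next if right.empty? # splitting at the largest x leaves one side empty
    lv = left.sum / left.size.to_f
    rv = right.sum / right.size.to_f
    err = left.sum { |r| (r - lv)**2 } + right.sum { |r| (r - rv)**2 }
    if err < best_err
      best_err = err
      best = Stump.new(t, lv, rv)
    end
  end
  best
end

# Start from the mean of ys, then add shrunken stumps fit to the residuals.
def gradient_boost(xs, ys, rounds: 50, learning_rate: 0.1)
  base = ys.sum / ys.size.to_f
  preds = Array.new(ys.size, base)
  stumps = []
  rounds.times do
    # For squared error, the negative gradient is the residual y - F(x).
    residuals = ys.zip(preds).map { |y, p| y - p }
    stump = fit_stump(xs, residuals)
    stumps << stump
    preds = preds.each_with_index.map { |p, i| p + learning_rate * stump.predict(xs[i]) }
  end
  # The final model is the base value plus the sum of all shrunken stumps.
  ->(x) { base + stumps.sum { |s| learning_rate * s.predict(x) } }
end

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0]
model = gradient_boost(xs, ys, rounds: 100, learning_rate: 0.1)
puts model.call(2.0).round(2) # prints 1.0
puts model.call(5.0).round(2) # prints 5.0
```

The learning rate shrinks each tree's contribution, so the predictions approach the targets gradually over many rounds rather than in a single step.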
Requirements
- LightGBM (bundled with the lightgbm gem)
- Red Datasets
- The MNIST dataset
gem install red-datasets
gem install lightgbm
Run
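require 'lightgbm'
require 'datasets'

# Load the MNIST training and test sets with Red Datasets
train_mnist = Datasets::MNIST.new(type: :train)
test_mnist = Datasets::MNIST.new(type: :test)

# Scale each image's 784 (28x28) pixel values from 0..255 to 0.0..1.0
train_x = train_mnist.map { |r| r.pixels.map { |i| i / 255.0 } }
train_y = train_mnist.map(&:label)
test_x = test_mnist.map { |r| r.pixels.map { |i| i / 255.0 } }
test_y = test_mnist.map(&:label)

# Multiclass classification over the 10 digit classes (0-9)
params = {
  task: :train,
  boosting_type: :gbdt,
  objective: :multiclass,
  num_class: 10
}

train_set = LightGBM::Dataset.new(train_x, label: train_y)
booster = LightGBM.train(params, train_set)

# Save the trained model; to reuse it later without retraining, load it with:
booster.save_model("mnist_lightgbm.txt")
# booster = LightGBM::Booster.new(model_file: 'mnist_lightgbm.txt')

# predict returns one array of 10 class probabilities per test image;
# the index of the largest probability is the predicted digit
result = booster.predict(test_x)
result.map! { |i| i.index(i.max) }

# Accuracy = fraction of test images whose predicted digit matches the label
accuracy = test_y.zip(result).count { |i, j| i == j } / test_y.size.to_f
puts accuracy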
Output:
0.9727
High accuracy (about 97.3% on the test set) in a short amount of time?
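With `objective: :multiclass`, `booster.predict` gives one array of `num_class` class probabilities per test row, which is why the script takes the index of the maximum as the predicted digit. That decoding and the accuracy calculation can be checked on their own with made-up scores. The `softmax`, `decode` and `accuracy` helpers and the sample data below are hypothetical. `softmax` only illustrates how per-class raw scores become probabilities; because `exp` is increasing, the argmax is the same before and after it.

```ruby
# Hypothetical helpers mirroring the evaluation step of the script above.

# Softmax: turn one row of raw per-class scores into probabilities summing to 1.
def softmax(scores)
  m = scores.max # subtract the max before exp for numerical stability
  exps = scores.map { |s| Math.exp(s - m) }
  total = exps.sum
  exps.map { |e| e / total }
end

# Same decoding as the script: the index of the largest value in each row.
def decode(rows)
  rows.map { |row| row.index(row.max) }
end

# Fraction of positions where the predicted label equals the true label.
def accuracy(labels, predicted)
  labels.zip(predicted).count { |y, p| y == p } / labels.size.to_f
end

# Made-up raw scores for 3 samples and 3 classes (not real model output).
raw = [
  [2.0, 0.5, -1.0], # largest score at class 0
  [0.1, 0.2, 3.0],  # class 2
  [1.5, 1.4, 0.0]   # class 0
]
probs = raw.map { |r| softmax(r) }
pred = decode(probs)
p pred                      # prints [0, 2, 0]
p accuracy([0, 2, 1], pred) # prints 0.6666666666666666
```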