Linear Model Trees

Introduction

Linear model trees combine linear models and decision trees to create a hybrid model that produces better predictions, and leads to better insights, than either model alone. A linear model tree is simply a decision tree with linear models at its nodes. It can be seen as a piecewise linear model whose knots are learned via a decision tree algorithm. LMTs can be used for regression problems (e.g. with linear regression models instead of population means) or classification problems (e.g. with logistic regression instead of population modes).
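
To make this concrete, here is a toy sketch (hypothetical numbers, not a model from this post) of the kind of prediction function an LMT represents, with a single knot at x = 5 and a separate linear model on each side:

def lmt_prediction(x):
  # one split ("knot") at x = 5, learned by the tree algorithm
  if x < 5:
    return 1.0 + 2.0 * x   # left leaf: its own intercept and slope
  return 16.0 - 1.0 * x    # right leaf: its own intercept and slope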

[Figure: machine learning models arranged along axes of predictive accuracy and interpretability]

Above is a heuristic chart of machine learning models along the axes of accuracy and interpretability. The upper-right quadrant is the best, with both high performance and high interpretability. The chart places LMTs in that quadrant: highly interpretable and highly performant.


LMT vs. Others

Below we will demonstrate LMTs with the open source auto-mpg dataset. The auto-mpg dataset concerns the fuel consumption of 398 vehicles from the 1970s and early 1980s. We will predict fuel consumption (mpg) based on vehicle weight, model year, horsepower, acceleration, engine displacement, and number of cylinders. The Jupyter notebook linked at the bottom of this post contains the full exploration of this data and model building; the results are summarized here.
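
For readers following along, here is one way the data might be loaded (a sketch, not the notebook's exact code; it assumes pandas and the UCI repository's copy of auto-mpg, and drops the handful of rows with missing horsepower):

import pandas as pd

url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
cols = ['mpg', 'cylinders', 'displacement', 'horsepower', 'weight',
        'acceleration', 'model_year', 'origin']
# horsepower uses '?' for missing values; car names are tab-prefixed,
# so comment='\t' drops them
df = pd.read_csv(url, names=cols, na_values='?', comment='\t',
                 sep=' ', skipinitialspace=True).dropna()

features = ['weight', 'model_year', 'horsepower', 'acceleration',
            'displacement', 'cylinders']
X = df[features].to_numpy()
y = df['mpg'].to_numpy()
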
[Table: predictive performance (MSE) of linear regression, a single decision tree, GBT, and LMT on auto-mpg]
The above table shows the performance of four different algorithms at the task of predicting mpg on this dataset. It is no surprise that Gradient Boosting Trees (GBT) performs best, as this algorithm often produces the best predictive performance. However, LMT performs very nearly as well, and as we will see below, it has other benefits. Linear regression and a single decision tree perform poorly compared to the other two models.
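
The notebook contains the exact setup; here is a sketch of how the scikit-learn baselines might be scored (it assumes the X and y arrays from the loading snippet above, and the hyperparameters shown are illustrative; the LMT implementation appears later in this post and can be scored on the same folds by hand):

from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingRegressor

models = [('linear regression', LinearRegression()),
          ('decision tree', DecisionTreeRegressor(min_samples_leaf=10)),
          ('gradient boosting trees', GradientBoostingRegressor(random_state=0))]
for name, model in models:
  # cross-validated mean squared error, lower is better
  mse = -cross_val_score(model, X, y, scoring='neg_mean_squared_error', cv=5).mean()
  print(f'{name}: MSE = {mse:.2f}')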

LMT vs. GBT

GBT achieved the best predictive performance as measured by MSE. The next question is: what actually drives the gas mileage of these cars? We dive into this with variable importance on the GBT model and get the following:

[Figure: GBT variable importance for the auto-mpg features]

GBT's variable importance attribute tells us that weight is the most important feature, followed by horsepower, acceleration, displacement, and model_year, which are all of similar importance. Unfortunately, GBT tells us nothing about the numerical magnitude or sign of each feature's impact, nor about the form of the relationship between these features and mpg.

The LMT produces just two splits, for a total of three leaf nodes. It splits first at horsepower = 78, and for horsepower >= 78 it splits again at horsepower = 97. We will call the three subpopulations low power, medium power, and high power.

Inspecting the weights from the linear model tree gives us a very different understanding of what affects fuel efficiency than we got from the other models. While there are some commonalities across the subpopulations our LMT has identified, we also see some significant differences.

[Figure: the fitted intercept and coefficients of each leaf's linear model, by subpopulation]
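
As an aside, the GBT importances above take a single attribute access to extract in scikit-learn; a sketch, assuming the X, y, and features objects from the loading snippet earlier:

from sklearn.ensemble import GradientBoostingRegressor

gbt = GradientBoostingRegressor(random_state=0).fit(X, y)
# feature_importances_ holds relative importances that sum to 1
for name, importance in sorted(zip(features, gbt.feature_importances_),
                               key=lambda pair: -pair[1]):
  print(f'{name}: {importance:.3f}')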

Summary of LMT benefits

For my final words on Linear Model Trees, here is a summary of their benefits:

  • LMTs are powerfully interpretable. Get insights into linear and non-linear relationships in your data. This can lead to other modeling hypotheses or product ideas.
  • LMTs identify subpopulations with different behavior.
  • LMTs can easily identify and utilize linear relationships. Tree-based models (including Random Forests and Gradient Boosting Trees) need a lot of effort to learn a line, because they fit a piecewise constant model that predicts the average of all observations in each leaf node; they therefore require many splits to approximate a linear relationship, as the sketch after this list demonstrates.
  • Overfitting (high variance) can be avoided by using cross-validation to optimize the minimum node size and maximum tree depth.
  • LMTs can work well with a modest amount of data (compared to many nonlinear models).
  • LMTs often produce simple models that are easy to implement in a production system, even if that system is not written in the same language you use for modeling.
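
To see the piecewise-constant point concretely, here is a minimal sketch (assumes scikit-learn) contrasting a depth-limited tree with a single linear model on an exactly linear target:

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

X = np.linspace(0, 10, 200).reshape(-1, 1)
y = 3 * X.ravel()  # an exactly linear relationship

tree = DecisionTreeRegressor(max_depth=3).fit(X, y)  # 8 leaves -> an 8-step staircase
line = LinearRegression().fit(X, y)                  # one slope and one intercept

X_new = np.array([[2.5], [7.5]])
print(tree.predict(X_new))  # piecewise-constant approximations of 7.5 and 22.5
print(line.predict(X_new))  # exact: [ 7.5 22.5]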

LMT implementation
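
The implementation below grows the tree greedily: each node fits a linear model to the data it receives, then searches every feature for the split of that model's residuals that most reduces the sum of squared errors. Children are fit to the parent's residuals, so a prediction is the sum of the model outputs along the path from the root to a leaf.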

import math
import numpy as np

class LinearModelTree:
  def __init__(self, min_node_size, node_model_fit_func, min_split_improvement=0):
    # min_node_size: smallest number of rows allowed in a leaf
    # node_model_fit_func: callback taking (X, y) and returning a fitted model
    #   that exposes .predict() (and .intercept_/.coef_ if you use serialize())
    # min_split_improvement: minimum SSE reduction required to accept a split
    self.min_node_size = min_node_size
    self.node_fit_func = node_model_fit_func
    self.min_split_improvement = min_split_improvement
    self.root = None

  def linear_model_predict(self, X, y):
    # fit a linear model via the user-supplied callback and return its
    # in-sample predictions along with the fitted model
    linear_model = self.node_fit_func(X, y)
    predictions = linear_model.predict(X)
    if np.isnan(predictions).any():
      print('got nan predictions')
    return predictions, linear_model

  def build_tree(self, X, y):
    self.root = Node.build_node_recursive(self, X, y)

  def predict(self, X):
    predicted = []
    for i in range(X.shape[0]):
      predicted.append(self.predict_one(X[i, :]))
    return predicted

  def predict_one(self, X):
    return self.root.predict_one(X)

  def node_count(self):
    return self.root.node_count()

  def serialize(self):
    return self.root.serialize()


class Node:
  def __init__(self, feature_idx, pivot_value, linear_model, raw_count):
    self.feature_idx = feature_idx
    self.pivot_value = pivot_value
    self.linear_model = linear_model
    self.raw_count = raw_count
    self.left = None
    self.right = None

  def node_count(self):
    if self.feature_idx is not None:
      return 1 + self.left.node_count() + self.right.node_count()
    else:
      return 1

  def predict_one(self, x):
    # each node's model was fit to the residuals of its parent, so the final
    # prediction is the sum of the model outputs along the root-to-leaf path
    local_value = self.linear_model.predict(x.reshape(1, -1))[0]
    if self.feature_idx is None:
      return local_value
    else:
      if x[self.feature_idx] < self.pivot_value:
        child_value = self.left.predict_one(x)
      else:
        child_value = self.right.predict_one(x)
    return child_value + local_value

  @staticmethod
  def build_node_recursive(tree, X, y):
    # fit a model at this node, then, if a worthwhile split exists, grow
    # children on this model's residuals so each level refines its parent
    feature_idx, pivot_value, linear_model, residuals = Node.find_best_split(tree, X, y)
    node = Node(feature_idx, pivot_value, linear_model, X.shape[0])
    if feature_idx is not None:
      left_X, left_residuals, right_X, right_residuals = Node.split_on_pivot(X, residuals, feature_idx, pivot_value)
      node.left = Node.build_node_recursive(tree, left_X, left_residuals)
      node.right = Node.build_node_recursive(tree, right_X, right_residuals)

    return node

  @staticmethod
  def split_on_pivot(X, y, feature_idx, pivot_value):
    # sort rows by column feature_idx, then cut at the first row whose value
    # reaches pivot_value: the left child gets < pivot, the right child gets >=
    sorting_indices = X[:, feature_idx].argsort()
    sorted_X = X[sorting_indices]
    sorted_y = y[sorting_indices]
    pivot_idx = np.argmax(sorted_X[:, feature_idx] >= pivot_value)
    return sorted_X[:pivot_idx, :], sorted_y[:pivot_idx], sorted_X[pivot_idx:, :], sorted_y[pivot_idx:]

  @staticmethod
  def find_best_split(tree, X, y):
    # fit this node's model, then scan every feature for the split of its
    # residuals that most reduces the total sum of squared errors (SSE)
    predictions, linear_model = tree.linear_model_predict(X, y)
    residuals = y - predictions
    row_count = X.shape[0]
    sse = (residuals ** 2).sum()
    sum_residual = residuals.sum()
    best_sse = sse
    best_feature = None
    best_feature_pivot = None
    for feature_idx in range(X.shape[1]):
      # sort by column feature_idx
      sorting_indices = X[:, feature_idx].argsort()
      sorted_X = X[sorting_indices]
      sorted_resid = residuals[sorting_indices]
      sum_residual_left = 0
      sum_residual_right = sum_residual
      sum_squared_left = 0
      sum_squared_right = sse
      count_left = 0
      count_right = row_count
      pivot_idx = 0
      # advance the pivot one row at a time, maintaining running sums so each
      # candidate split's SSE can be evaluated in O(1)
      while count_right >= tree.min_node_size:
        # move the current row from the right side to the left side
        raw_residual = sorted_resid[pivot_idx]
        sum_residual_left += raw_residual
        sum_residual_right -= raw_residual
        sum_squared_left += raw_residual*raw_residual
        sum_squared_right -= raw_residual*raw_residual
        count_left += 1
        count_right -= 1
        pivot_idx += 1
        if count_left >= tree.min_node_size and count_right >= tree.min_node_size:
          # consider a split: each side's SSE is its residuals' squared deviation
          # about that side's own mean, via n*var = sum_x2 - (sum_x)^2/n;
          # max(0.0, ...) guards against small negative values caused by
          # floating-point cancellation
          rmse_left = math.sqrt(max(0.0, (count_left * sum_squared_left) - (sum_residual_left * sum_residual_left))) / count_left
          sse_left = rmse_left * rmse_left * count_left
          rmse_right = math.sqrt(max(0.0, (count_right * sum_squared_right) - (sum_residual_right * sum_residual_right))) / count_right
          sse_right = rmse_right * rmse_right * count_right
          split_sse = sse_left + sse_right

          if (split_sse < best_sse and sse - split_sse > tree.min_split_improvement and
              # only split where the feature value changes, so the pivot is well-defined
              (count_left <= 1 or sorted_X[pivot_idx, feature_idx] != sorted_X[pivot_idx - 1, feature_idx])):
            best_sse = split_sse
            best_feature = feature_idx
            best_feature_pivot = sorted_X[pivot_idx, feature_idx]

    return best_feature, best_feature_pivot, linear_model, residuals

  def serialize(self, prefix='T'):
    # compact text dump: the prefix encodes the path from the root (L = left,
    # R = right); internal nodes show their split, leaves show their model
    if self.feature_idx is not None:
      self_str = f',rc:{self.raw_count}, f: {self.feature_idx}, v:{self.pivot_value}'
      return "\n" + prefix + (self_str +
                              self.left.serialize(prefix + 'L') +
                              self.right.serialize(prefix + 'R')
                              )
    else:
      self_str = f',rc:{self.raw_count},f:_,v:_,int:{self.linear_model.intercept_},coef:{self.linear_model.coef_}'
      return "\n" + prefix + self_str

Reference

The Best of Both Worlds: Linear Model Trees. Convoy Tech Blog. https://medium.com/convoy-tech/the-best-of-both-worlds-linear-model-trees-7c9ce139767d
