Introduction
Linear model trees (LMTs) combine linear models and decision trees to create a hybrid model that produces better predictions and better insights than either model alone. A linear model tree is simply a decision tree with linear models at its nodes. It can be seen as a piecewise linear model whose knots are learned by a decision tree algorithm. LMTs can be used for regression problems (e.g., with linear regression models in place of the population mean at each leaf) or classification problems (e.g., with logistic regression in place of the population mode).
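To make the structure concrete, here is a minimal sketch of what a fitted two-leaf LMT looks like as a prediction function. The split point and coefficients below are purely illustrative, not fitted values:

def lmt_predict_sketch(horsepower, weight):
    # A fitted LMT is a decision tree whose nodes hold linear models.
    # Illustrative only: this split and these coefficients are made up.
    if horsepower < 100:                 # split learned by the tree algorithm
        return 45.0 - 0.006 * weight     # linear model for one subpopulation
    return 30.0 - 0.003 * weight         # a different linear model for the other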
Imagine a heuristic chart of machine learning models along the axes of accuracy and interpretability. The upper-right quadrant, with both high performance and high interpretability, is the best place to be, and it is where LMTs sit: highly interpretable and highly performant.
LMT vs. Others
Below we will demonstrate LMTs with the open source auto-mpg dataset, which concerns the fuel consumption of 398 vehicles from the 1970s and early 1980s. We will predict fuel consumption (mpg) from vehicle weight, model year, horsepower, acceleration, engine displacement, and number of cylinders. The Jupyter notebook linked at the bottom of this post contains the full exploration of the data and model building; the results are summarized here.
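As a sketch of how this data can be loaded (assuming the standard UCI copy of the file; the column names follow the dataset's documentation):

import pandas as pd

# UCI hosts auto-mpg as a whitespace-separated file with no header row;
# the horsepower column uses '?' for a handful of missing values.
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'
columns = ['mpg', 'cylinders', 'displacement', 'horsepower', 'weight',
           'acceleration', 'model_year', 'origin', 'car_name']
df = pd.read_csv(url, sep=r'\s+', names=columns, na_values='?').dropna()

features = ['weight', 'model_year', 'horsepower', 'acceleration',
            'displacement', 'cylinders']
X = df[features].to_numpy(dtype=float)
y = df['mpg'].to_numpy(dtype=float)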
Comparing the performance of four different algorithms at predicting mpg on this dataset, it is no surprise that Gradient Boosting Trees (GBT) performs best, as this algorithm often produces the best predictive performance. However, the LMT performs very nearly as well, and as we will see below, it has other benefits. Linear regression and a single decision tree perform poorly compared to the other two models.
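A sketch of how such a comparison can be run, assuming X and y from the loading sketch above (the model settings here are illustrative, and the LMT itself is exercised at the end of the post once the class is defined):

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

models = {
    'linear regression': LinearRegression(),
    'decision tree': DecisionTreeRegressor(random_state=0),
    'GBT': GradientBoostingRegressor(random_state=0),
}
for name, model in models.items():
    model.fit(X_train, y_train)
    print(name, 'MSE:', mean_squared_error(y_test, model.predict(X_test)))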
LMT vs. GBT
GBT wins on predictive performance as measured by MSE. The next question is: what actually drives the gas mileage of these cars? We dig into this with variable importance on the GBT model and get the following:
GBT's variable importance attribute tells us that weight is the most important feature, followed by horsepower, acceleration, displacement, and model_year, which are all of similar importance. Unfortunately, GBT tells us nothing about the numerical magnitude or sign of each feature's impact, nor about the shape of its relationship with mpg.
The LMT, by contrast, produces just two splits, for a total of three leaf nodes. It splits first at horsepower = 78, and for horsepower >= 78 it splits again at horsepower = 97. We will call the three subpopulations low power, medium power, and high power.
Inspecting the weights from the linear model tree gives us a very different understanding of what affects fuel efficiency than we got from the other models. While there are some commonalities across the subpopulations our LMT has identified, we also see some significant differences.
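For reference, the GBT importances above can be read off the fitted model like so (a sketch, assuming the models dict and features list from the earlier snippets):

for name, importance in sorted(zip(features, models['GBT'].feature_importances_),
                               key=lambda pair: -pair[1]):
    print(f'{name}: {importance:.3f}')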
Summary of LMT benefits
For my final words on Linear Model Trees, here is a summary of their benefits:
- LMTs are powerfully interpretable. They give insight into both linear and non-linear relationships in your data, which can lead to new modeling hypotheses or product ideas.
- LMTs identify subpopulations with different behavior.
- LMTs can easily identify and exploit linear relationships. Tree-based models (including Random Forests and Gradient Boosting Trees) take a lot of effort to learn a line because they fit a piecewise constant model, predicting the average of all observations in each leaf node, so they require many splits to approximate a linear relationship (see the sketch after this list).
- Overfitting (high variance) can be avoided by using cross-validation to optimize the minimum node size and maximum tree depth.
- LMTs can work well with a modest amount of data (compared to many nonlinear models).
- LMTs often produce simple models that are easy to implement in a production system, even if that system is not written in the language you use for modeling.
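A quick sketch of the piecewise-constant point: on purely linear synthetic data, a two-parameter line captures what a decision tree needs many leaves to approximate.

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X_lin = rng.uniform(0, 10, size=(500, 1))
y_lin = 3.0 * X_lin[:, 0] + rng.normal(0.0, 0.5, size=500)  # y is linear in x plus noise

line = LinearRegression().fit(X_lin, y_lin)          # two parameters capture the trend
tree = DecisionTreeRegressor(max_depth=6).fit(X_lin, y_lin)
print('leaves the tree spends approximating one line:', tree.get_n_leaves())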
LMT implementation
import math
import numpy as np


class LinearModelTree:
    def __init__(self, min_node_size, node_model_fit_func, min_split_improvement=0):
        self.min_node_size = min_node_size
        self.node_fit_func = node_model_fit_func
        self.min_split_improvement = min_split_improvement
        self.root = None

    def linear_model_predict(self, X, y):
        # Fit the node-level linear model and return its in-sample predictions.
        linear_model = self.node_fit_func(X, y)
        predictions = linear_model.predict(X)
        if np.isnan(predictions).any():
            print('got nan predictions')
        return predictions, linear_model

    def build_tree(self, X, y):
        self.root = Node.build_node_recursive(self, X, y)

    def predict(self, X):
        predicted = []
        for i in range(X.shape[0]):
            predicted.append(self.predict_one(X[i, :]))
        return predicted

    def predict_one(self, x):
        return self.root.predict_one(x)

    def node_count(self):
        return self.root.node_count()

    def serialize(self):
        return self.root.serialize()
class Node:
    def __init__(self, feature_idx, pivot_value, linear_model, raw_count):
        self.feature_idx = feature_idx  # None for leaf nodes
        self.pivot_value = pivot_value
        self.linear_model = linear_model
        self.raw_count = raw_count
        self.left = None
        self.right = None

    def node_count(self):
        if self.feature_idx is not None:
            return 1 + self.left.node_count() + self.right.node_count()
        else:
            return 1

    def predict_one(self, x):
        # Each node contributes its own linear model's prediction; children are
        # fit on residuals, so the final prediction is the sum along the path.
        local_value = self.linear_model.predict(x.reshape(1, -1))[0]
        if self.feature_idx is None:
            return local_value
        if x[self.feature_idx] < self.pivot_value:
            child_value = self.left.predict_one(x)
        else:
            child_value = self.right.predict_one(x)
        return child_value + local_value

    @staticmethod
    def build_node_recursive(tree, X, y):
        feature_idx, pivot_value, linear_model, residuals = Node.find_best_split(tree, X, y)
        node = Node(feature_idx, pivot_value, linear_model, X.shape[0])
        if feature_idx is not None:
            # Children are trained on this node's residuals, boosting-style.
            left_X, left_residuals, right_X, right_residuals = Node.split_on_pivot(
                X, residuals, feature_idx, pivot_value)
            node.left = Node.build_node_recursive(tree, left_X, left_residuals)
            node.right = Node.build_node_recursive(tree, right_X, right_residuals)
        return node

    @staticmethod
    def split_on_pivot(X, y, feature_idx, pivot_value):
        # Sort by column feature_idx, then cut at the first row >= pivot_value.
        sorting_indices = X[:, feature_idx].argsort()
        sorted_X = X[sorting_indices]
        sorted_y = y[sorting_indices]
        pivot_idx = np.argmax(sorted_X[:, feature_idx] >= pivot_value)
        return sorted_X[:pivot_idx, :], sorted_y[:pivot_idx], sorted_X[pivot_idx:, :], sorted_y[pivot_idx:]

    @staticmethod
    def find_best_split(tree, X, y):
        predictions, linear_model = tree.linear_model_predict(X, y)
        residuals = y - predictions
        row_count = X.shape[0]
        sse = (residuals ** 2).sum()
        sum_residual = residuals.sum()
        best_sse = sse
        best_feature = None
        best_feature_pivot = None
        for feature_idx in range(X.shape[1]):
            # Sort by column feature_idx and sweep the pivot left to right,
            # maintaining running sums so each candidate split is O(1) to score.
            sorting_indices = X[:, feature_idx].argsort()
            sorted_X = X[sorting_indices]
            sorted_resid = residuals[sorting_indices]
            sum_residual_left = 0.0
            sum_residual_right = sum_residual
            sum_squared_left = 0.0
            sum_squared_right = sse
            count_left = 0
            count_right = row_count
            pivot_idx = 0
            while count_right >= tree.min_node_size:
                # Advance the pivot by one row.
                raw_residual = sorted_resid[pivot_idx]
                sum_residual_left += raw_residual
                sum_residual_right -= raw_residual
                sum_squared_left += raw_residual * raw_residual
                sum_squared_right -= raw_residual * raw_residual
                count_left += 1
                count_right -= 1
                pivot_idx += 1
                if count_left >= tree.min_node_size and count_right >= tree.min_node_size:
                    # Consider a split. From the sum, sum of squares, and n:
                    # rmse = sqrt(n * sum_x2 - (sum_x)^2) / n. The max(0.0, ...)
                    # guards against tiny negative values from floating-point round-off.
                    rmse_left = math.sqrt(max(0.0, count_left * sum_squared_left - sum_residual_left ** 2)) / count_left
                    sse_left = rmse_left * rmse_left * count_left
                    rmse_right = math.sqrt(max(0.0, count_right * sum_squared_right - sum_residual_right ** 2)) / count_right
                    sse_right = rmse_right * rmse_right * count_right
                    split_sse = sse_left + sse_right
                    if (split_sse < best_sse and sse - split_sse > tree.min_split_improvement and
                            # Only split where the feature value changes from the previous row.
                            (count_left <= 1 or sorted_X[pivot_idx, feature_idx] != sorted_X[pivot_idx - 1, feature_idx])):
                        best_sse = split_sse
                        best_feature = feature_idx
                        best_feature_pivot = sorted_X[pivot_idx, feature_idx]
        return best_feature, best_feature_pivot, linear_model, residuals

    def serialize(self, prefix='T'):
        if self.feature_idx is not None:
            self_str = f',rc:{self.raw_count},f:{self.feature_idx},v:{self.pivot_value}'
            return ('\n' + prefix + self_str +
                    self.left.serialize(prefix + 'L') +
                    self.right.serialize(prefix + 'R'))
        else:
            self_str = (f',rc:{self.raw_count},f:_,v:_,'
                        f'int:{self.linear_model.intercept_},coef:{self.linear_model.coef_}')
            return '\n' + prefix + self_str
Reference
The Best of Both Worlds: Linear Model Trees (Convoy Tech): https://medium.com/convoy-tech/the-best-of-both-worlds-linear-model-trees-7c9ce139767d