Asif Rahman

Ensemble decision trees in Numba

2021-09-21

Representing ensembles of decision trees as NumPy arrays, with fast Numba-compiled operations to traverse them.

import math
import numpy as np
from numba import njit, prange


@njit
def take(X, inds):
    """Gather one element per row of a 2D array (multidimensional indexing for numba).

    Equivalent to ``X[np.arange(len(X)), inds]`` in plain NumPy, written as an
    explicit loop so it compiles under Numba.

    Args:
        X: 2D array; row ``r`` contributes the element ``X[r, inds[r]]``.
        inds: column index to pick for each row of ``X``.
    Returns:
        1D array of length ``len(X)`` with the same dtype as ``X``.
    """
    nrows = len(X)
    picked = np.zeros(nrows, dtype=X.dtype)
    for row in range(nrows):
        col = inds[row]
        picked[row] = X[row, col]
    return picked


@njit
def next_node(node_id, value, thr):
    """Step from ``node_id`` to one of its children by comparing ``value`` to ``thr``.

    Nodes use heap (breadth-first) numbering: the root is 0 and node ``k`` has
    children ``2k + 1`` (left) and ``2k + 2`` (right). Works elementwise when
    the arguments are arrays.

    Args:
        node_id: int or integer array of current node ids (root node_id = 0).
        value: feature value(s) being tested.
        thr: split threshold(s).
    Returns:
        The left child id when ``value <= thr``, the right child id when
        ``value > thr``.
    """
    left_child = (node_id << 1) + 1
    go_right = 1 * (thr < value)  # 0 or 1; also turns a bool array into ints
    return left_child + go_right


@njit
def leaf(X, features, thresholds, reset_leaf_index=1):
    """Find the leaf reached by each sample in a complete binary decision tree.

    The tree is stored in heap order (root at 0, children of node k at 2k + 1
    and 2k + 2). ``features`` and ``thresholds`` hold one entry per internal
    node, so a tree of depth d has 2**d - 1 of them.

    Args:
        X: 2D design matrix of shape [nsamples, nfeatures]
        features: feature (column of X) tested at each internal node, integer
            dtype, shape [internal_nodes]
        thresholds: split threshold at each internal node, shape
            [internal_nodes]; samples with value <= threshold go left.
        reset_leaf_index: when 1 (default), subtract the number of internal
            nodes so leaves are numbered 0 .. 2**depth - 1. Any other value
            returns the raw heap node id of the leaf.
    Returns:
        int64 array of shape [nsamples] with one leaf index per sample.
    """
    nsamples = len(X)
    node_id = np.zeros(nsamples, dtype=np.int64)
    n = len(features)
    # Largest depth with 2**depth - 1 <= n, computed with exact integers.
    # A float expression such as log(n+1)/log(2) can round to just below an
    # integer and then truncate to the wrong depth.
    depth = 0
    while (1 << (depth + 1)) - 1 <= n:
        depth += 1
    internal_nodes = (1 << depth) - 1
    for _ in range(depth):
        feature_ind = features[node_id]
        value = take(X, feature_ind)
        thr = thresholds[node_id]
        node_id = next_node(node_id, value, thr)
    if reset_leaf_index == 1:
        # Leaves of a depth-d heap occupy ids [2**d - 1, 2**(d+1) - 1).
        node_id = node_id - internal_nodes
    return node_id


@njit
def leaf_tokens(X, trees, nleaves_per_tree):
    """Tokenize a design matrix by the leaf each sample reaches in every tree.

    Each tree's leaf indices are shifted by ``nleaves_per_tree * tree_index``.
    When ``nleaves_per_tree`` equals the true leaf count, tree ``t`` therefore
    yields tokens in ``[t * nleaves_per_tree, (t + 1) * nleaves_per_tree)``,
    and tokens from different trees never collide.

    Args:
        X: 2D design matrix of shape [nsamples, nfeatures]
        trees: 3D array defining an ensemble of decision trees, shape
            [ntrees, internal_nodes, 2]. ``[..., 0]`` holds feature indices
            (cast to int64 here) and ``[..., 1]`` holds split thresholds.
        nleaves_per_tree: number of leaves in each decision tree, 2**depth
    Returns:
        int64 array of shape [nsamples, ntrees] of leaf tokens.

    NOTE(review): this uses plain ``@njit``. Numba only parallelizes a
    ``prange`` loop under ``parallel=True``, so the loop currently runs
    serially.
    """
    nrows = X.shape[0]
    ntrees = trees.shape[0]
    tokens = np.zeros((nrows, ntrees), dtype=np.int64)
    for t in prange(ntrees):
        offset = nleaves_per_tree * t
        feature_idx = trees[t, :, 0].astype(np.int64)
        split_thr = trees[t, :, 1]
        tokens[:, t] = leaf(X, feature_idx, split_thr) + offset
    return tokens


@njit
def random_ensemble_decision_trees(ntrees, depth, nfeatures):
    """Generate a random ensemble of complete binary decision trees.

    Args:
        ntrees: number of trees in the ensemble
        depth: height of each tree; each tree has 2**depth - 1 internal nodes
            (depth=0 gives trees with no internal nodes)
        nfeatures: number of features in the design matrix; feature indices
            are drawn from [0, nfeatures)
    Returns:
        trees: float64 array of shape [ntrees, internal_nodes, 2], where
            internal_nodes = 2**depth - 1 counts the non-leaf nodes (including
            the root). ``[..., 0]`` holds the feature index (stored as float)
            and ``[..., 1]`` holds a split threshold drawn from [0, 1).
    """
    internal_nodes = 2**(depth) - 1
    total_internal_nodes = ntrees * internal_nodes
    trees = np.zeros((total_internal_nodes, 2))  # columns: [features, thresholds]
    trees[:, 0] = np.random.randint(0, nfeatures, total_internal_nodes)
    trees[:, 1] = np.random.rand(total_internal_nodes)
    # Give the trailing size explicitly rather than -1. When depth=0 the array
    # is empty, and a -1 dimension cannot be inferred from zero elements.
    trees = trees.reshape(ntrees, internal_nodes, 2)  # [ntrees, internal_nodes, 2]
    return trees


# Demo: tokenize a random design matrix with a random ensemble of trees.
nsamples = 1000
nfeatures = 30
X = np.random.rand(nsamples, nfeatures)

ntrees, depth = 1000, 1
nleaves = 1 << depth  # leaves per tree, 2**depth
total_nleaves = nleaves * ntrees  # size of the combined leaf-token vocabulary
trees = random_ensemble_decision_trees(ntrees, depth, nfeatures)
leaves = leaf_tokens(X, trees, nleaves)