Skip to content

Commit 52280c7

Browse files
author
David Foster
committed
initial commit
1 parent af6801d commit 52280c7

19 files changed

+655
-0
lines changed

Diff for: .DS_Store

12 KB
Binary file not shown.

Diff for: .Rbuildignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
^.*\.Rproj$
2+
^\.Rproj\.user$

Diff for: DESCRIPTION

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Package: xgboostExplainer
2+
Title: XGBoost Model Explainer
3+
Version: 0.1
4+
Authors@R: person("David", "Foster", email = "[email protected]", role = c("aut", "cre"))
5+
Description: XGBoost is a very successful machine learning package based on boosted trees. This package allows the predictions from an xgboost model to be split into the impact of each feature, making the model as transparent as a linear regression or decision tree.
6+
Depends: R (>= 3.4.1)
7+
Imports: data.table, xgboost, waterfalls, scales, ggplot2
8+
License: GPL-3
9+
Encoding: UTF-8
10+
LazyData: true
11+
RoxygenNote: 6.0.1.9000

Diff for: NAMESPACE

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Generated by roxygen2: do not edit by hand
2+
3+
export(buildExplainer)
4+
export(explainPredictions)
5+
export(showWaterfall)
6+
import(data.table)
7+
import(ggplot2)
8+
import(scales)
9+
import(waterfalls)
10+
import(xgboost)

Diff for: R/.DS_Store

6 KB
Binary file not shown.

Diff for: R/buildExplainer.R

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#' Step 1: Build an xgboostExplainer
2+
#'
3+
#' This function outputs an xgboostExplainer (a data table that stores the feature impact breakdown for each leaf of each tree in an xgboost model). It is required as input into the explainPredictions and showWaterfall functions.
4+
#' @param xgb.model A trained xgboost model
5+
#' @param trainingData A DMatrix of data used to train the model
6+
#' @param type The objective function of the model - either "binary" (for binary:logistic) or "regression" (for reg:linear)
7+
#' @param base_score Default 0.5. The base_score variable of the xgboost model.
8+
#' @return The XGBoost Explainer for the model. This is a data table where each row is a leaf of a tree in the xgboost model
9+
#' and each column is the impact of each feature on the prediction at the leaf.
10+
#'
11+
#' The leaf and tree columns uniquely identify the node.
12+
#'
13+
#' The sum of the other columns equals the prediction at the leaf (log-odds if binary response).
14+
#'
15+
#' The 'intercept' column is identical for all rows and is analogous to the intercept term in a linear / logistic regression.
16+
#'
17+
#' @export
18+
#' @import data.table
19+
#' @import xgboost
20+
#' @examples
21+
#' library(xgboost)
22+
#' library(xgboostExplainer)
23+
#'
24+
#' set.seed(123)
25+
#'
26+
#' data(agaricus.train, package='xgboost')
27+
#'
28+
#' X = as.matrix(agaricus.train$data)
29+
#' y = agaricus.train$label
30+
#'
31+
#' train_idx = 1:5000
32+
#'
33+
#' xgb.train.data <- xgb.DMatrix(X[train_idx,], label = y[train_idx])
34+
#' xgb.test.data <- xgb.DMatrix(X[-train_idx,])
35+
#'
36+
#' param <- list(objective = "binary:logistic")
37+
#' xgb.model <- xgboost(param =param, data = xgb.train.data, nrounds=3)
38+
#'
39+
#' col_names = colnames(X)
40+
#'
41+
#' pred.train = predict(xgb.model,X)
42+
#' nodes.train = predict(xgb.model,X,predleaf =TRUE)
43+
#' trees = xgb.model.dt.tree(col_names, model = xgb.model)
44+
#'
45+
#' #### The XGBoost Explainer
46+
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5)
47+
#' pred.breakdown = explainPredictions(xgb.model, explainer, xgb.test.data)
48+
#'
49+
#' showWaterfall(xgb.model, explainer, slice(xgb.test.data, as.integer(2)), type = "binary")
50+
#' showWaterfall(xgb.model, explainer, slice(xgb.test.data, as.integer(8)), type = "binary")
51+
52+
53+
buildExplainer = function(xgb.model, trainingData, type = "binary", base_score = 0.5){

  # Column names are read from the DMatrix's ".Dimnames" attribute.
  # NOTE(review): assumes trainingData is an xgb.DMatrix built from a named
  # matrix — confirm for sparse inputs.
  col_names = attr(trainingData, ".Dimnames")[[2]]

  # Parse the model into one row per tree node; restrict to the trees that are
  # actually used for prediction (best_ntreelimit, e.g. after early stopping).
  trees = xgb.model.dt.tree(col_names, model = xgb.model,
                            n_first_tree = xgb.model$best_ntreelimit - 1)

  # BUG FIX: the original predicted on the global `xgb.train.data` instead of
  # the `trainingData` argument, so the function only worked when the caller
  # happened to use that exact variable name.
  nodes.train = predict(xgb.model, trainingData, predleaf = TRUE)

  cat('\nSTEP 1 of 2')
  # Fill in gain/cover/weight statistics for every node of every tree.
  tree_list = getStatsForTrees(trees, nodes.train, type = type, base_score = base_score)
  cat('\n\nSTEP 2 of 2')
  # Convert the per-tree stats into one row per leaf with per-feature impacts.
  explainer = buildExplainerFromTreeList(tree_list, col_names)

  cat('\n\nDONE!\n')

  return (explainer)
}

Diff for: R/buildExplainerFromTreeList.R

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
#' @import data.table
3+
#' @import xgboost
4+
5+
buildExplainerFromTreeList = function(tree_list, col_names){

  # Accepts a list of trees (with stats filled in) and the feature names.
  # Outputs a data table with the impact of each variable + intercept for
  # each leaf of each tree; `leaf` and `tree` columns identify the node.

  # All-NULL skeleton fixes the column set and order for rbindlist.
  tree_list_breakdown <- vector("list", length(col_names) + 3)
  names(tree_list_breakdown) = c(col_names, 'intercept', 'leaf', 'tree')

  num_trees = length(tree_list)

  cat('\n\nGetting breakdown for each leaf of each tree...\n')
  pb <- txtProgressBar(style = 3)

  # Collect one breakdown table per tree, then bind once at the end.
  # (The original grew the table with rbindlist inside the loop — O(n^2) —
  # and used 1:num_trees, which breaks when the tree list is empty.)
  breakdowns = vector("list", num_trees + 1)
  breakdowns[[1]] = tree_list_breakdown

  for (x in seq_len(num_trees)){
    tree_breakdown = getTreeBreakdown(tree_list[[x]], col_names)
    tree_breakdown$tree = x - 1  # xgboost trees are 0-indexed
    breakdowns[[x + 1]] = tree_breakdown
    setTxtProgressBar(pb, x / num_trees)
  }

  return (rbindlist(breakdowns))
}

Diff for: R/explainPredictions.R

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#' Step 2: Get multiple prediction breakdowns from a trained xgboost model
2+
#'
3+
#' This function outputs the feature impact breakdown of a set of predictions made using an xgboost model.
4+
#' @param xgb.model A trained xgboost model
5+
#' @param explainer The output from the buildExplainer function, for this model
6+
#' @param data A DMatrix of data to be explained
7+
#' @return A data table where each row is an observation in the data and each column is the impact of each feature on the prediction.
8+
#'
9+
#' The sum of the row equals the prediction of the xgboost model for this observation (log-odds if binary response).
10+
#'
11+
#' @export
12+
#' @import data.table
13+
#' @import xgboost
14+
#' @examples
15+
#' library(xgboost)
16+
#' library(xgboostExplainer)
17+
#'
18+
#' set.seed(123)
19+
#'
20+
#' data(agaricus.train, package='xgboost')
21+
#'
22+
#' X = as.matrix(agaricus.train$data)
23+
#' y = agaricus.train$label
24+
#'
25+
#' train_idx = 1:5000
26+
#'
27+
#' xgb.train.data <- xgb.DMatrix(X[train_idx,], label = y[train_idx])
28+
#' xgb.test.data <- xgb.DMatrix(X[-train_idx,])
29+
#'
30+
#' param <- list(objective = "binary:logistic")
31+
#' xgb.model <- xgboost(param =param, data = xgb.train.data, nrounds=3)
32+
#'
33+
#' col_names = colnames(X)
34+
#'
35+
#' pred.train = predict(xgb.model,X)
36+
#' nodes.train = predict(xgb.model,X,predleaf =TRUE)
37+
#' trees = xgb.model.dt.tree(col_names, model = xgb.model)
38+
#'
39+
#' #### The XGBoost Explainer
40+
#' explainer = buildExplainer(xgb.model,xgb.train.data, type="binary", base_score = 0.5)
41+
#' pred.breakdown = explainPredictions(xgb.model, explainer, xgb.test.data)
42+
#'
43+
#' showWaterfall(xgb.model, explainer, slice(xgb.test.data, as.integer(2)), type = "binary")
44+
#' showWaterfall(xgb.model, explainer, slice(xgb.test.data, as.integer(8)), type = "binary")
45+
46+
explainPredictions = function(xgb.model, explainer, data){

  # Accepts the explainer (breakdown for each leaf of each tree) and a DMatrix.
  # Returns the per-feature breakdown of each prediction as a data table; the
  # row sum equals the model's margin prediction for that observation.

  # Leaf index reached in every tree, for every observation (rows x trees).
  nodes = predict(xgb.model, data, predleaf = TRUE)

  # All impact columns: everything except the trailing 'leaf' and 'tree'.
  # (Renamed from `colnames`, which shadowed base::colnames.)
  impact_cols = names(explainer)[1:(ncol(explainer) - 2)]

  # Accumulator: one row per observation, initialised to zero impact.
  preds_breakdown = data.table(matrix(0, nrow = nrow(nodes), ncol = length(impact_cols)))
  setnames(preds_breakdown, impact_cols)

  num_trees = ncol(nodes)

  cat('\n\nExtracting the breakdown of each prediction...\n')
  pb <- txtProgressBar(style = 3)
  # seq_len guards against a zero-tree model (1:0 would iterate c(1, 0)).
  for (x in seq_len(num_trees)){
    nodes_for_tree = nodes[, x]
    tree_breakdown = explainer[tree == x - 1]

    # Look up each observation's leaf in this tree and add its impacts.
    preds_breakdown_for_tree = tree_breakdown[match(nodes_for_tree, tree_breakdown$leaf), ]
    preds_breakdown = preds_breakdown + preds_breakdown_for_tree[, impact_cols, with = FALSE]

    setTxtProgressBar(pb, x / num_trees)
  }

  cat('\n\nDONE!\n')

  return (preds_breakdown)
}

Diff for: R/findLeaves.R

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
#' @import data.table
3+
#' @import xgboost
4+
findLeaves = function(tree, currentnode){

  # Recursively collect the ids of all leaves in the subtree rooted at
  # `currentnode` of a single xgboost tree (as a data.table row index).
  if (tree[currentnode, 'Feature'] == 'Leaf'){
    leaves = currentnode
  } else {
    leftnode = tree[currentnode, Yes]
    rightnode = tree[currentnode, No]
    # BUG FIX: the original recursed with the literal strings
    # 'leftnode'/'rightnode' (plus a stray with=FALSE argument) instead of
    # the child-node values, so the recursion could never descend the tree.
    leaves = c(findLeaves(tree, leftnode), findLeaves(tree, rightnode))
  }

  # NOTE(review): assumes Yes/No hold indices usable to re-index `tree` the
  # same way `currentnode` is — confirm against xgb.model.dt.tree output.
  return (sort(leaves))
}

Diff for: R/findPath.R

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
#' @import data.table
3+
#' @import xgboost
4+
findPath = function(tree, currentnode, path = c()){

  # Walks upwards from `currentnode` to the root of a single tree, recording
  # every node visited. `path` is an internal accumulator — do not set it.
  visited = path
  node = currentnode

  while (node > 0) {
    visited = c(visited, node)
    # The parent is the row whose Yes or No pointer equals this node's ID.
    node_id = tree[Node == node, ID]
    node = c(tree[Yes == node_id, Node], tree[No == node_id, Node])
  }

  # Node 0 (the root) is always on the path; return node ids in ascending order.
  return (sort(c(visited, 0)))
}
17+

Diff for: R/getLeafBreakdown.R

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
#' @import data.table
3+
#' @import xgboost
4+
getLeafBreakdown = function(tree,leaf,col_names){

  ####accepts a tree, the leaf id to breakdown and column names
  ####outputs a list of the impact of each variable + intercept

  # Start every feature's impact at zero; only features on the root-to-leaf
  # path get a non-zero contribution.
  impacts = as.list(rep(0,length(col_names)))
  names(impacts) = col_names

  # Nodes on the path from the root to this leaf, in ascending Node order.
  path = findPath(tree,leaf)
  reduced_tree = tree[Node %in% path,.(Feature,uplift_weight)]

  # Row 1 is the root: its uplift is the intercept term.
  impacts$intercept=reduced_tree[1,uplift_weight]
  # Shift uplifts up one row so each split feature is paired with the uplift
  # realised by taking that split (i.e. the child's uplift).
  reduced_tree[,uplift_weight:=shift(uplift_weight,type='lead')]

  # Sum the uplift attributed to each feature along the path.
  tmp = reduced_tree[,.(sum=sum(uplift_weight)),by=Feature]
  # Drop the last group — the 'Leaf' row, whose shifted uplift is NA.
  tmp = tmp[-nrow(tmp)]
  impacts[tmp[,Feature]]=tmp[,sum]

  return (impacts)
}
24+
25+

Diff for: R/getStatsForTrees.R

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
2+
#' @import data.table
3+
#' @import xgboost
4+
getStatsForTrees = function(trees, nodes.train, type = "binary", base_score = 0.5){
  #Accepts data table of tree (the output of xgb.model.dt.tree)
  #Returns a list of tree, with the stats filled in
  # NOTE(review): nodes.train is accepted but not read in this function —
  # confirm whether it is needed or vestigial.

  # Work on a copy: := below mutates in place.
  tree_list = copy(trees)
  tree_list[,leaf := Feature == 'Leaf']
  tree_list[,H:=Cover]

  non.leaves = which(tree_list[,leaf]==F)


  # The default cover (H) seems to lose precision so this loop recalculates it for each node of each tree
  # Iterating in reverse guarantees both children are updated before their parent.
  cat('\n\nRecalculating the cover for each non-leaf... \n')
  pb <- txtProgressBar(style=3)
  j = 0
  for (i in rev(non.leaves)){
    left = tree_list[i,Yes]
    right = tree_list[i,No]
    # Parent cover = sum of its children's covers.
    tree_list[i,H:=tree_list[ID==left,H] + tree_list[ID==right,H]]
    j=j+1
    setTxtProgressBar(pb, j / length(non.leaves))
  }


  # Base margin: raw score for regression, log-odds for binary classification.
  if (type == 'regression'){
    base_weight = base_score
  } else{
    base_weight = log(base_score / (1-base_score))
  }

  # Leaf weight = base margin plus the leaf's Quality (its output value).
  tree_list[leaf==T,weight:=base_weight + Quality]

  tree_list[,previous_weight:=base_weight]
  # Root has no parent, so its previous weight is zero.
  tree_list[1,previous_weight:=0]

  # Gradient sum at each leaf, reconstructed from weight = -G/H.
  tree_list[leaf==T,G:=-weight*H]

  # One data table per tree from here on.
  tree_list = split(tree_list,as.factor(tree_list$Tree))
  num_tree_list = length(tree_list)
  treenums = as.character(0:(num_tree_list-1))
  t = 0
  cat('\n\nFinding the stats for the xgboost trees...\n')
  pb <- txtProgressBar(style=3)
  for (tree in tree_list){
    t=t+1
    num_nodes = nrow(tree)
    # Bottom-up again: children's G must be known before the parent's.
    non_leaf_rows = rev(which(tree[,leaf]==F))
    for (r in non_leaf_rows){
      left = tree[r,Yes]
      right = tree[r,No]
      leftG = tree[ID==left,G]
      rightG = tree[ID==right,G]

      tree[r,G:=leftG+rightG]
      # Internal-node weight from the standard xgboost formula w = -G/H.
      w=tree[r,-G/H]

      tree[r,weight:=w]
      # Each child records its parent's weight, for the uplift below.
      tree[ID==left,previous_weight:=w]
      tree[ID==right,previous_weight:=w]
    }

    # Uplift = how much this node changed the margin relative to its parent.
    tree[,uplift_weight:=weight-previous_weight]
    setTxtProgressBar(pb, t / num_tree_list)
  }

  return (tree_list)
}

Diff for: R/getTreeBreakdown.R

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
#' @import data.table
3+
#' @import xgboost
4+
getTreeBreakdown = function(tree, col_names){

  # Accepts one tree (data table) and the feature names; outputs a data table
  # with one row per leaf giving the impact of each variable + the intercept.

  # All-NULL skeleton pins the column set and order for rbindlist.
  breakdown <- vector("list", length(col_names) + 2)
  names(breakdown) = c(col_names, 'intercept', 'leaf')

  leaf_nodes = tree[leaf==T, Node]

  for (node in leaf_nodes){
    # Per-feature contributions along the root-to-leaf path.
    row = getLeafBreakdown(tree, node, col_names)
    row$leaf = node
    breakdown = rbindlist(append(list(breakdown), list(row)))
  }

  return (breakdown)
}

0 commit comments

Comments
 (0)