Skip to content

Commit 84fccec

Browse files
committed
Use output-last SHAP C API shape
1 parent 383e8cf commit 84fccec

4 files changed

Lines changed: 21 additions & 29 deletions

File tree

include/xgboost/c_api.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,9 +1243,6 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat
12431243
* Beginning iteration.
12441244
* "iteration_end": int
12451245
* End iteration. Set to 0 to use all trees.
1246-
* "strict_shape": bool
1247-
* Whether output shapes should include the output-group dimension even when
1248-
* there is only one output group.
12491246
*
12501247
* @param out_values_shape Shape of feature SHAP values (copy before use).
12511248
* @param out_values_dim Dimension of feature SHAP values.

python-package/xgboost/interpret.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def _capi_shap_values(
7676
"algorithm": "auto",
7777
"iteration_begin": int(iteration_range[0]),
7878
"iteration_end": int(iteration_range[1]),
79-
"strict_shape": False,
8079
}
8180
_check_call(
8281
_LIB.XGBoosterInterpretShapValues(
@@ -92,10 +91,12 @@ def _capi_shap_values(
9291
ctypes.byref(bias),
9392
)
9493
)
95-
return (
96-
_prediction_output(values_shape, values_dim, values, False),
97-
_prediction_output(bias_shape, bias_dim, bias, False),
98-
)
94+
values_out = _prediction_output(values_shape, values_dim, values, False)
95+
bias_out = _prediction_output(bias_shape, bias_dim, bias, False)
96+
if values_out.shape[-1] == 1:
97+
values_out = values_out[..., 0]
98+
bias_out = bias_out[..., 0]
99+
return values_out, bias_out
99100

100101

101102
def shap_values( # pylint: disable=too-many-arguments

src/c_api/c_api.cc

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,7 +1389,6 @@ XGB_DLL int XGBoosterInterpretShapValues(
13891389
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
13901390
auto iteration_begin = OptionalArg<Integer>(config, "iteration_begin", Integer::Int{0});
13911391
auto iteration_end = OptionalArg<Integer>(config, "iteration_end", Integer::Int{0});
1392-
bool strict_shape = OptionalArg<Boolean>(config, "strict_shape", false);
13931392

13941393
learner->Predict(p_m, false, &entry.predictions, iteration_begin, iteration_end, false, false,
13951394
true, false, false);
@@ -1407,29 +1406,23 @@ XGB_DLL int XGBoosterInterpretShapValues(
14071406
for (std::size_t row = 0; row < rows; ++row) {
14081407
for (std::size_t group = 0; group < groups; ++group) {
14091408
std::size_t contrib_offset = row * groups * (cols + 1) + group * (cols + 1);
1410-
std::size_t value_offset = row * groups * cols + group * cols;
1411-
std::copy_n(contribs.cbegin() + contrib_offset, cols, values.begin() + value_offset);
1409+
for (std::size_t col = 0; col < cols; ++col) {
1410+
std::size_t value_offset = row * cols * groups + col * groups + group;
1411+
values[value_offset] = contribs[contrib_offset + col];
1412+
}
14121413
bias[row * groups + group] = contribs[contrib_offset + cols];
14131414
}
14141415
}
14151416

14161417
auto &values_shape = local.prediction_shape;
14171418
auto &bias_shape = local.prediction_shape_1;
1418-
if (groups == 1 && !strict_shape) {
1419-
values_shape.resize(2);
1420-
values_shape[0] = rows;
1421-
values_shape[1] = cols;
1422-
bias_shape.resize(1);
1423-
bias_shape[0] = rows;
1424-
} else {
1425-
values_shape.resize(3);
1426-
values_shape[0] = rows;
1427-
values_shape[1] = groups;
1428-
values_shape[2] = cols;
1429-
bias_shape.resize(2);
1430-
bias_shape[0] = rows;
1431-
bias_shape[1] = groups;
1432-
}
1419+
values_shape.resize(3);
1420+
values_shape[0] = rows;
1421+
values_shape[1] = cols;
1422+
values_shape[2] = groups;
1423+
bias_shape.resize(2);
1424+
bias_shape[0] = rows;
1425+
bias_shape[1] = groups;
14331426

14341427
xgboost_CHECK_C_ARG_PTR(out_values_dim);
14351428
xgboost_CHECK_C_ARG_PTR(out_values_shape);

tests/cpp/c_api/test_c_api.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,6 @@ TEST(CAPI, InterpretShapValues) {
742742
shap_config["algorithm"] = String{"auto"};
743743
shap_config["iteration_begin"] = Integer{0};
744744
shap_config["iteration_end"] = Integer{0};
745-
shap_config["strict_shape"] = Boolean{false};
746745
auto sshap_config = Json::Dump(shap_config);
747746

748747
bst_ulong const *values_shape{nullptr};
@@ -755,11 +754,13 @@ TEST(CAPI, InterpretShapValues) {
755754
&values_shape, &values_dim, &values, &bias_shape,
756755
&bias_dim, &bias),
757756
0);
758-
ASSERT_EQ(values_dim, 2);
757+
ASSERT_EQ(values_dim, 3);
759758
ASSERT_EQ(values_shape[0], n_samples);
760759
ASSERT_EQ(values_shape[1], n_features);
761-
ASSERT_EQ(bias_dim, 1);
760+
ASSERT_EQ(values_shape[2], 1);
761+
ASSERT_EQ(bias_dim, 2);
762762
ASSERT_EQ(bias_shape[0], n_samples);
763+
ASSERT_EQ(bias_shape[1], 1);
763764

764765
Json pred_config{Object{}};
765766
pred_config["type"] = Integer{2};

0 commit comments

Comments
 (0)