-
Notifications
You must be signed in to change notification settings - Fork 3.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[cpp] fix: to_if_else missing check for threshold when missing_type is not none #6818
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your interest in LightGBM. I don't really understand this change... can you please add a test that fails on master
but passes here? Or at least provide a minimal, reproducible example explaining the problem?
For example:
in a mismatch predict value in cpp code with predict value in python interface
The Python package calls into the C API to generate predictions, so I don't understand how this statement could be true.
The original feature was designed to convert the model parameter file into C++ source code in an if-else format. For more details, you can refer to the related pull request on GitHub: PR #469. However, there is a bug in this feature: if the decision_type is not 2 (2 means None) but is either 8 or 10, the tool will generate incorrect C++ code. This incorrect code fails to check for thresholds when the values are not missing. You can observe the difference in generated cpp code after my fix, particularly in Tree0. Here's a small script generated by an LLM to create a model parameter file named lightgbm_model.txt: import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split
# Generate training data
np.random.seed(42) # Set random seed for reproducibility
size = 1000 # Increase data size to 1000 rows
num_features = 10 # Number of features
# Generate feature data, including some NaN values
X = np.random.rand(size, num_features) # Generate 1000 rows and 10 columns of random numbers
# Randomly insert NaN values
nan_indices = np.random.choice(size * num_features, size=int(size * num_features * 0.1), replace=False) # 10% NaN
for idx in nan_indices:
row = idx // num_features
col = idx % num_features
X[row, col] = np.nan
# Generate target variable, assuming it is a linear combination plus some noise
y = 0.5 * X[:, 0] + 2 * X[:, 1] - 1.5 * X[:, 2] + np.random.normal(0, 0.1, size) # Target variable is linearly related to features
# Convert data to pandas DataFrame
X_df = pd.DataFrame(X, columns=[f'feature{i+1}' for i in range(num_features)])
y_df = pd.Series(y)
# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X_df, y_df, test_size=0.2, random_state=42)
# Create LightGBM dataset
train_data = lgb.Dataset(X_train, label=y_train, free_raw_data=False)
test_data = lgb.Dataset(X_test, label=y_test, free_raw_data=False)
# Set LightGBM parameters
params = {
'objective': 'regression',
'metric': 'mse',
'verbosity': -1, # Turn off LightGBM output
'num_leaves': 3,
'learning_rate': 0.1,
'feature_fraction': 0.9,
'max_bin': 255,
'min_data_in_leaf': 20,
}
# Train model
model = lgb.train(params, train_data, num_boost_round=3)
model.save_model('lightgbm_model.txt')
print("Model parameters have been exported to lightgbm_model.txt") lightgbm_model.txt:
Convert it to cpp file: ./lightgbm task=convert_model input_model=lightgbm_model.txt convert_model_language=cpp convert_model=a.cpp BeforeCpp if-else file (a.cpp): #include "gbdt.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
#include <ctime>
#include <sstream>
#include <chrono>
#include <string>
#include <vector>
#include <utility>
namespace LightGBM {
double PredictTree0(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return 0.3177485457277795; } else { fval = arr[2];if (!std::isnan(fval)) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }
double PredictTree0ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return 0.3177485457277795; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }
double PredictTree1(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval)) {fval = arr[0];if (std::isnan(fval)) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }
double PredictTree1ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }
double PredictTree2(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return -0.037347034887016677; } else { fval = arr[0];if (std::isnan(fval)) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }
double PredictTree2ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return -0.037347034887016677; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }
double (*PredictTreePtr[])(const double*) = { PredictTree0 , PredictTree1 , PredictTree2 };
void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
int early_stop_round_counter = 0;
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);
}
++early_stop_round_counter;
if (early_stop->round_period == early_stop_round_counter) {
if (early_stop->callback_function(output, num_tree_per_iteration_))
return;
early_stop_round_counter = 0;
}
}
}
double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0ByMap , PredictTree1ByMap , PredictTree2ByMap };
void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
int early_stop_round_counter = 0;
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);
}
++early_stop_round_counter;
if (early_stop->round_period == early_stop_round_counter) {
if (early_stop->callback_function(output, num_tree_per_iteration_))
return;
early_stop_round_counter = 0;
}
}
}
void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
PredictRaw(features, output, early_stop);
if (average_output_) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] /= num_iteration_for_pred_;
}
}
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output);
}
}
void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
PredictRawByMap(features, output, early_stop);
if (average_output_) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] /= num_iteration_for_pred_;
}
}
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output);
}
}
double PredictTree0Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return 0; } else { fval = arr[2];if (!std::isnan(fval)) {return 1; } else { return 2; } } }
double PredictTree0LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return 0; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {return 1; } else { return 2; } } }
double PredictTree1Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval)) {fval = arr[0];if (std::isnan(fval)) {return 0; } else { return 2; } } else { return 1; } }
double PredictTree1LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return 0; } else { return 2; } } else { return 1; } }
double PredictTree2Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return 0; } else { fval = arr[0];if (std::isnan(fval)) {return 1; } else { return 2; } } }
double PredictTree2LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return 0; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return 1; } else { return 2; } } }
double (*PredictTreeLeafPtr[])(const double*) = { PredictTree0Leaf , PredictTree1Leaf , PredictTree2Leaf };
void GBDT::PredictLeafIndex(const double* features, double *output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) {
output[i] = (*PredictTreeLeafPtr[i])(features);
}
}
double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0LeafByMap , PredictTree1LeafByMap , PredictTree2LeafByMap };
void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) {
output[i] = (*PredictTreeLeafByMapPtr[i])(features);
}
}
} // namespace LightGBM Take a look at Tree0, it only checks for fval is nan or not, does not check for the threshold: double PredictTree0(const double * arr) {
const std::vector < uint32_t > cat_threshold = {};
double fval = 0.0 f;
fval = arr[1];
if (std::isnan(fval)) {
return 0.3177485457277795;
} else {
fval = arr[2];
if (!std::isnan(fval)) {
return 0.44806984060967808;
} else {
return 0.36627160319740354;
}
}
} After my fixwhole cpp file (a.cpp): #include "gbdt.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
#include <ctime>
#include <sstream>
#include <chrono>
#include <string>
#include <vector>
#include <utility>
namespace LightGBM {
double PredictTree0(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0.3177485457277795; } else { fval = arr[2];if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }
double PredictTree0ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0.3177485457277795; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }
double PredictTree1(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }
double PredictTree1ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }
double PredictTree2(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.46947632301562686) {return -0.037347034887016677; } else { fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }
double PredictTree2ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.46947632301562686) {return -0.037347034887016677; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }
double (*PredictTreePtr[])(const double*) = { PredictTree0 , PredictTree1 , PredictTree2 };
void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
int early_stop_round_counter = 0;
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);
}
++early_stop_round_counter;
if (early_stop->round_period == early_stop_round_counter) {
if (early_stop->callback_function(output, num_tree_per_iteration_))
return;
early_stop_round_counter = 0;
}
}
}
double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0ByMap , PredictTree1ByMap , PredictTree2ByMap };
void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
int early_stop_round_counter = 0;
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
for (int i = 0; i < num_iteration_for_pred_; ++i) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);
}
++early_stop_round_counter;
if (early_stop->round_period == early_stop_round_counter) {
if (early_stop->callback_function(output, num_tree_per_iteration_))
return;
early_stop_round_counter = 0;
}
}
}
void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
PredictRaw(features, output, early_stop);
if (average_output_) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] /= num_iteration_for_pred_;
}
}
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output);
}
}
void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
PredictRawByMap(features, output, early_stop);
if (average_output_) {
for (int k = 0; k < num_tree_per_iteration_; ++k) {
output[k] /= num_iteration_for_pred_;
}
}
if (objective_function_ != nullptr) {
objective_function_->ConvertOutput(output, output);
}
}
double PredictTree0Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0; } else { fval = arr[2];if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 1; } else { return 2; } } }
double PredictTree0LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 1; } else { return 2; } } }
double PredictTree1Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 0; } else { return 2; } } else { return 1; } }
double PredictTree1LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 0; } else { return 2; } } else { return 1; } }
double PredictTree2Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.46947632301562686) {return 0; } else { fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 1; } else { return 2; } } }
double PredictTree2LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.46947632301562686) {return 0; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 1; } else { return 2; } } }
double (*PredictTreeLeafPtr[])(const double*) = { PredictTree0Leaf , PredictTree1Leaf , PredictTree2Leaf };
void GBDT::PredictLeafIndex(const double* features, double *output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) {
output[i] = (*PredictTreeLeafPtr[i])(features);
}
}
double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0LeafByMap , PredictTree1LeafByMap , PredictTree2LeafByMap };
void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {
int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
for (int i = 0; i < total_tree; ++i) {
output[i] = (*PredictTreeLeafByMapPtr[i])(features);
}
}
} // namespace LightGBM Tree0: double PredictTree0(const double * arr) {
const std::vector < uint32_t > cat_threshold = {};
double fval = 0.0 f;
fval = arr[1];
if (std::isnan(fval) || fval <= 0.43908158089896321) { // add check for threshold
return 0.3177485457277795;
} else {
fval = arr[2];
if (!std::isnan(fval) && fval <= 0.50555253269984324) { // add check for threshold
return 0.44806984060967808;
} else {
return 0.36627160319740354;
}
}
} |
When missing_type is not None and fval is valid, the code before did not compare with the threshold value, resulting in a mismatch predict value in cpp code with predict value in python interface.