[cpp] fix: to_if_else missing check for threshold when missing_type is not none #6818

Open
wants to merge 3 commits into
base: master

@wuciting wuciting commented Feb 7, 2025

When missing_type is not None and fval is a valid (non-missing) value, the previous code did not compare fval with the threshold. As a result, predictions from the generated C++ code did not match predictions from the Python interface.

@wuciting wuciting changed the title fix: cpp to_if_else missing check for threshold when missing_type is … fix: cpp to_if_else missing check for threshold when missing_type is not none Feb 7, 2025
@wuciting wuciting changed the title fix: cpp to_if_else missing check for threshold when missing_type is not none [cpp] fix: to_if_else missing check for threshold when missing_type is not none Feb 7, 2025
Collaborator

@jameslamb jameslamb left a comment

Thanks for your interest in LightGBM. I don't really understand this change... can you please add a test that fails on master but passes here? Or at least provide a minimal, reproducible example explaining the problem?

For example:

in a mismatch predict value in cpp code with predict value in python interface

The Python package calls into the C API to generate predictions, so I don't understand how this statement could be true.

@wuciting
Author

wuciting commented Feb 8, 2025

The original feature was designed to convert the model parameter file into C++ source code in an if-else format. For more details, you can refer to the related pull request on GitHub: PR #469.

However, there is a bug in this feature: if a node's decision_type is 8 or 10 rather than 2 (2 means missing_type is None), the tool generates incorrect C++ code. The generated code only tests whether the value is missing and never compares non-missing values with the threshold.
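To see why 8 and 10 differ from 2, it may help to decode the bit fields. The sketch below assumes the layout I believe LightGBM uses for decision_type (bit 0: categorical flag, bit 1: default-left flag, bits 2-3: missing type, with 0 = None, 1 = Zero, 2 = NaN). The helper name decode_decision_type is made up for illustration and is not part of LightGBM.

```python
# Assumed encoding of the missing type stored in bits 2-3 of decision_type.
MISSING_TYPES = {0: "None", 1: "Zero", 2: "NaN"}


def decode_decision_type(decision_type: int) -> dict:
    """Split a decision_type value into its assumed bit fields."""
    return {
        "categorical": bool(decision_type & 1),
        "default_left": bool(decision_type & 2),
        "missing_type": MISSING_TYPES.get((decision_type >> 2) & 3, "unknown"),
    }


if __name__ == "__main__":
    for value in (2, 8, 10):
        print(value, decode_decision_type(value))
```

Under this reading, 2 is "missing_type None, default left", 8 is "missing_type NaN, default right", and 10 is "missing_type NaN, default left", which matches the `||` and `&&` forms seen in the generated conditions.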

You can observe the difference in the generated C++ code before and after my fix, particularly in Tree0.

Here's a small script generated by an LLM to create a model parameter file named lightgbm_model.txt:

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import train_test_split

# Generate training data
np.random.seed(42)  # Set random seed for reproducibility
size = 1000  # Increase data size to 1000 rows
num_features = 10  # Number of features

# Generate feature data, including some NaN values
X = np.random.rand(size, num_features)  # Generate 1000 rows and 10 columns of random numbers

# Randomly insert NaN values
nan_indices = np.random.choice(size * num_features, size=int(size * num_features * 0.1), replace=False)  # 10% NaN
for idx in nan_indices:
    row = idx // num_features
    col = idx % num_features
    X[row, col] = np.nan

# Generate target variable, assuming it is a linear combination plus some noise
y = 0.5 * X[:, 0] + 2 * X[:, 1] - 1.5 * X[:, 2] + np.random.normal(0, 0.1, size)  # Target variable is linearly related to features

# Convert data to pandas DataFrame
X_df = pd.DataFrame(X, columns=[f'feature{i+1}' for i in range(num_features)])
y_df = pd.Series(y)

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X_df, y_df, test_size=0.2, random_state=42)

# Create LightGBM dataset
train_data = lgb.Dataset(X_train, label=y_train, free_raw_data=False)
test_data = lgb.Dataset(X_test, label=y_test, free_raw_data=False)

# Set LightGBM parameters
params = {
    'objective': 'regression',
    'metric': 'mse',
    'verbosity': -1,  # Turn off LightGBM output
    'num_leaves': 3,
    'learning_rate': 0.1,
    'feature_fraction': 0.9,
    'max_bin': 255,
    'min_data_in_leaf': 20,
}

# Train model
model = lgb.train(params, train_data, num_boost_round=3)
model.save_model('lightgbm_model.txt')

print("Model parameters have been exported to lightgbm_model.txt")

lightgbm_model.txt:

tree
version=v4
num_class=1
num_tree_per_iteration=1
label_index=0
max_feature_idx=9
objective=regression
feature_names=feature1 feature2 feature3 feature4 feature5 feature6 feature7 feature8 feature9 feature10
feature_infos=[0.0030208713561612477:0.9966832779346797] [3.0718845382415871e-05:0.99971767328613059] [1.1634755366141114e-05:0.99889261203329061] [0.00040154402533987277:0.99946068105967312] [0.0013536257154681541:0.99941372577066656] [0.00022703821825553749:0.99890469976307672] [0.00041028943917487126:0.99413936122116753] [5.2826932296801132e-05:0.99762829740215653] [0.00090583524609022525:0.99742293905355417] [0.0011120941961874076:0.99955770325043858]
tree_sizes=394 402 402

Tree=0
num_leaves=3
num_cat=0
split_feature=1 2
split_gain=151 67.4877
threshold=0.43908158089896321 0.50555253269984324
decision_type=10 8
left_child=-1 -2
right_child=1 -3
leaf_value=0.3177485457277795 0.44806984060967808 0.36627160319740354
leaf_weight=395 190 215
leaf_count=395 190 215
internal_value=0.36174 0.404646
internal_weight=0 405
internal_count=800 405
is_linear=0
shrinkage=1


Tree=1
num_leaves=3
num_cat=0
split_feature=2 0
split_gain=80.9689 24.6402
threshold=0.47134327177395619 1.0000000180025095e-35
decision_type=8 10
left_child=1 -1
right_child=-2 -3
leaf_value=-0.036787212124237648 -0.026726479674212891 0.047840481589521179
leaf_weight=39 469 292
leaf_count=39 469 292
internal_value=0 0.0378692
internal_weight=0 331
internal_count=800 331
is_linear=0
shrinkage=0.1


Tree=2
num_leaves=3
num_cat=0
split_feature=1 0
split_gain=120.883 23.018
threshold=0.46947632301562686 1.0000000180025095e-35
decision_type=10 10
left_child=-1 -2
right_child=1 -3
leaf_value=-0.037347034887016677 -0.036852757930755618 0.048212644289095129
leaf_weight=416 35 349
leaf_count=416 35 349
internal_value=0 0.0404593
internal_weight=0 384
internal_count=800 384
is_linear=0
shrinkage=0.1


end of trees

feature_importances:
feature1=2
feature2=2
feature3=2

parameters:
[boosting: gbdt]
[objective: regression]
[metric: l2]
[tree_learner: serial]
[device_type: cpu]
[data_sample_strategy: bagging]
[data: ]
[valid: ]
[num_iterations: 3]
[learning_rate: 0.1]
[num_leaves: 3]
[num_threads: 0]
[seed: 0]
[deterministic: 0]
[force_col_wise: 0]
[force_row_wise: 0]
[histogram_pool_size: -1]
[max_depth: -1]
[min_data_in_leaf: 20]
[min_sum_hessian_in_leaf: 0.001]
[bagging_fraction: 1]
[pos_bagging_fraction: 1]
[neg_bagging_fraction: 1]
[bagging_freq: 0]
[bagging_seed: 3]
[feature_fraction: 0.9]
[feature_fraction_bynode: 1]
[feature_fraction_seed: 2]
[extra_trees: 0]
[extra_seed: 6]
[early_stopping_round: 0]
[early_stopping_min_delta: 0]
[first_metric_only: 0]
[max_delta_step: 0]
[lambda_l1: 0]
[lambda_l2: 0]
[linear_lambda: 0]
[min_gain_to_split: 0]
[drop_rate: 0.1]
[max_drop: 50]
[skip_drop: 0.5]
[xgboost_dart_mode: 0]
[uniform_drop: 0]
[drop_seed: 4]
[top_rate: 0.2]
[other_rate: 0.1]
[min_data_per_group: 100]
[max_cat_threshold: 32]
[cat_l2: 10]
[cat_smooth: 10]
[max_cat_to_onehot: 4]
[top_k: 20]
[monotone_constraints: ]
[monotone_constraints_method: basic]
[monotone_penalty: 0]
[feature_contri: ]
[forcedsplits_filename: ]
[refit_decay_rate: 0.9]
[cegb_tradeoff: 1]
[cegb_penalty_split: 0]
[cegb_penalty_feature_lazy: ]
[cegb_penalty_feature_coupled: ]
[path_smooth: 0]
[interaction_constraints: ]
[verbosity: -1]
[saved_feature_importance_type: 0]
[use_quantized_grad: 0]
[num_grad_quant_bins: 4]
[quant_train_renew_leaf: 0]
[stochastic_rounding: 1]
[linear_tree: 0]
[max_bin: 255]
[max_bin_by_feature: ]
[min_data_in_bin: 3]
[bin_construct_sample_cnt: 200000]
[data_random_seed: 1]
[is_enable_sparse: 1]
[enable_bundle: 1]
[use_missing: 1]
[zero_as_missing: 0]
[feature_pre_filter: 1]
[pre_partition: 0]
[two_round: 0]
[header: 0]
[label_column: ]
[weight_column: ]
[group_column: ]
[ignore_column: ]
[categorical_feature: ]
[forcedbins_filename: ]
[precise_float_parser: 0]
[parser_config_file: ]
[objective_seed: 5]
[num_class: 1]
[is_unbalance: 0]
[scale_pos_weight: 1]
[sigmoid: 1]
[boost_from_average: 1]
[reg_sqrt: 0]
[alpha: 0.9]
[fair_c: 1]
[poisson_max_delta_step: 0.7]
[tweedie_variance_power: 1.5]
[lambdarank_truncation_level: 30]
[lambdarank_norm: 1]
[label_gain: ]
[lambdarank_position_bias_regularization: 0]
[eval_at: ]
[multi_error_top_k: 1]
[auc_mu_weights: ]
[num_machines: 1]
[local_listen_port: 12400]
[time_out: 120]
[machine_list_filename: ]
[machines: ]
[gpu_platform_id: -1]
[gpu_device_id: -1]
[gpu_use_dp: 0]
[num_gpu: 1]

end of parameters

pandas_categorical:[]
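To check which trees contain nodes with a decision_type other than 2, the relevant lines can be pulled out of the model text. This is only a plain-text sketch that assumes the "Tree=N" and "decision_type=..." layout shown in the file above; the function name decision_types_by_tree is hypothetical.

```python
def decision_types_by_tree(model_text: str) -> dict:
    """Map each tree index to the list of decision_type values of its nodes."""
    result = {}
    current_tree = None
    for raw_line in model_text.splitlines():
        line = raw_line.strip()
        if line.startswith("Tree="):
            current_tree = int(line.split("=", 1)[1])
        elif line.startswith("decision_type=") and current_tree is not None:
            values = line.split("=", 1)[1].split()
            result[current_tree] = [int(v) for v in values]
    return result


# Excerpt of the decision_type lines from lightgbm_model.txt above.
SAMPLE = """Tree=0
decision_type=10 8
Tree=1
decision_type=8 10
Tree=2
decision_type=10 10
"""

if __name__ == "__main__":
    print(decision_types_by_tree(SAMPLE))
```

Every node in this model has decision_type 8 or 10, so all three trees are affected.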

Convert it to a C++ file:

./lightgbm task=convert_model input_model=lightgbm_model.txt convert_model_language=cpp convert_model=a.cpp

Before

C++ if-else file (a.cpp):

#include "gbdt.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
#include <ctime>
#include <sstream>
#include <chrono>
#include <string>
#include <vector>
#include <utility>
namespace LightGBM {
double PredictTree0(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return 0.3177485457277795; } else { fval = arr[2];if (!std::isnan(fval)) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }
double PredictTree0ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return 0.3177485457277795; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }

double PredictTree1(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval)) {fval = arr[0];if (std::isnan(fval)) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }
double PredictTree1ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }

double PredictTree2(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return -0.037347034887016677; } else { fval = arr[0];if (std::isnan(fval)) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }
double PredictTree2ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return -0.037347034887016677; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }

double (*PredictTreePtr[])(const double*) = { PredictTree0 , PredictTree1 , PredictTree2 };

void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
        int early_stop_round_counter = 0;
        std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
        for (int i = 0; i < num_iteration_for_pred_; ++i) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);
                }
                ++early_stop_round_counter;
                if (early_stop->round_period == early_stop_round_counter) {
                        if (early_stop->callback_function(output, num_tree_per_iteration_))
                                return;
                        early_stop_round_counter = 0;
                }
        }
}

double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0ByMap , PredictTree1ByMap , PredictTree2ByMap };

void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
        int early_stop_round_counter = 0;
        std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
        for (int i = 0; i < num_iteration_for_pred_; ++i) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);
                }
                ++early_stop_round_counter;
                if (early_stop->round_period == early_stop_round_counter) {
                        if (early_stop->callback_function(output, num_tree_per_iteration_))
                                return;
                        early_stop_round_counter = 0;
                }
        }
}

void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
        PredictRaw(features, output, early_stop);
        if (average_output_) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] /= num_iteration_for_pred_;
                }
        }
        if (objective_function_ != nullptr) {
                objective_function_->ConvertOutput(output, output);
        }
}

void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
        PredictRawByMap(features, output, early_stop);
        if (average_output_) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] /= num_iteration_for_pred_;
                }
        }
        if (objective_function_ != nullptr) {
                objective_function_->ConvertOutput(output, output);
        }
}

double PredictTree0Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return 0; } else { fval = arr[2];if (!std::isnan(fval)) {return 1; } else { return 2; } } }
double PredictTree0LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return 0; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {return 1; } else { return 2; } } }

double PredictTree1Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval)) {fval = arr[0];if (std::isnan(fval)) {return 0; } else { return 2; } } else { return 1; } }
double PredictTree1LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval)) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return 0; } else { return 2; } } else { return 1; } }

double PredictTree2Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval)) {return 0; } else { fval = arr[0];if (std::isnan(fval)) {return 1; } else { return 2; } } }
double PredictTree2LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval)) {return 0; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval)) {return 1; } else { return 2; } } }

double (*PredictTreeLeafPtr[])(const double*) = { PredictTree0Leaf , PredictTree1Leaf , PredictTree2Leaf };

void GBDT::PredictLeafIndex(const double* features, double *output) const {
        int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
        for (int i = 0; i < total_tree; ++i) {
                output[i] = (*PredictTreeLeafPtr[i])(features);
        }
}
double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0LeafByMap , PredictTree1LeafByMap , PredictTree2LeafByMap };

void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {
        int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
        for (int i = 0; i < total_tree; ++i) {
                output[i] = (*PredictTreeLeafByMapPtr[i])(features);
        }
}
}  // namespace LightGBM

Take a look at Tree0: it only checks whether fval is NaN and never compares fval with the threshold:

double PredictTree0(const double* arr) {
  const std::vector<uint32_t> cat_threshold = {};
  double fval = 0.0f;
  fval = arr[1];
  if (std::isnan(fval)) {
    return 0.3177485457277795;
  } else {
    fval = arr[2];
    if (!std::isnan(fval)) {
      return 0.44806984060967808;
    } else {
      return 0.36627160319740354;
    }
  }
}
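A small Python mirror of this function makes the mismatch concrete. The thresholds and leaf values are copied from Tree=0 in lightgbm_model.txt; the function names and the sample row are invented for illustration, and the threshold-aware variant is my reading of the intended rule for decision_type=10 8 (NaN follows the default direction, anything else is compared with the threshold).

```python
import math

# Values copied from Tree=0 in lightgbm_model.txt.
THRESHOLD_NODE0 = 0.43908158089896321
THRESHOLD_NODE1 = 0.50555253269984324
LEAVES = (0.3177485457277795, 0.44806984060967808, 0.36627160319740354)


def tree0_nan_only(arr):
    """Mirror of the generated code shown above: only NaN is tested."""
    if math.isnan(arr[1]):
        return LEAVES[0]
    if not math.isnan(arr[2]):
        return LEAVES[1]
    return LEAVES[2]


def tree0_with_threshold(arr):
    """Assumed intended rule: NaN takes the default side, otherwise compare."""
    if math.isnan(arr[1]) or arr[1] <= THRESHOLD_NODE0:  # decision_type=10
        return LEAVES[0]
    if not math.isnan(arr[2]) and arr[2] <= THRESHOLD_NODE1:  # decision_type=8
        return LEAVES[1]
    return LEAVES[2]


if __name__ == "__main__":
    row = [0.0, 0.2, 0.9]  # hypothetical input with no missing values
    print(tree0_nan_only(row), tree0_with_threshold(row))
```

For this row, arr[1] = 0.2 lies below the first threshold, so the threshold-aware version returns the first leaf, while the NaN-only version falls through to the second leaf. The two agree only when the relevant feature is actually missing.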

After my fix

Whole C++ file (a.cpp):

#include "gbdt.h"
#include <LightGBM/utils/common.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/prediction_early_stop.h>
#include <ctime>
#include <sstream>
#include <chrono>
#include <string>
#include <vector>
#include <utility>
namespace LightGBM {
double PredictTree0(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0.3177485457277795; } else { fval = arr[2];if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }
double PredictTree0ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0.3177485457277795; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 0.44806984060967808; } else { return 0.36627160319740354; } } }

double PredictTree1(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }
double PredictTree1ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036787212124237648; } else { return 0.047840481589521179; } } else { return -0.026726479674212891; } }

double PredictTree2(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.46947632301562686) {return -0.037347034887016677; } else { fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }
double PredictTree2ByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.46947632301562686) {return -0.037347034887016677; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return -0.036852757930755618; } else { return 0.048212644289095129; } } }

double (*PredictTreePtr[])(const double*) = { PredictTree0 , PredictTree1 , PredictTree2 };

void GBDT::PredictRaw(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
        int early_stop_round_counter = 0;
        std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
        for (int i = 0; i < num_iteration_for_pred_; ++i) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] += (*PredictTreePtr[i * num_tree_per_iteration_ + k])(features);
                }
                ++early_stop_round_counter;
                if (early_stop->round_period == early_stop_round_counter) {
                        if (early_stop->callback_function(output, num_tree_per_iteration_))
                                return;
                        early_stop_round_counter = 0;
                }
        }
}

double (*PredictTreeByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0ByMap , PredictTree1ByMap , PredictTree2ByMap };

void GBDT::PredictRawByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
        int early_stop_round_counter = 0;
        std::memset(output, 0, sizeof(double) * num_tree_per_iteration_);
        for (int i = 0; i < num_iteration_for_pred_; ++i) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] += (*PredictTreeByMapPtr[i * num_tree_per_iteration_ + k])(features);
                }
                ++early_stop_round_counter;
                if (early_stop->round_period == early_stop_round_counter) {
                        if (early_stop->callback_function(output, num_tree_per_iteration_))
                                return;
                        early_stop_round_counter = 0;
                }
        }
}

void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {
        PredictRaw(features, output, early_stop);
        if (average_output_) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] /= num_iteration_for_pred_;
                }
        }
        if (objective_function_ != nullptr) {
                objective_function_->ConvertOutput(output, output);
        }
}

void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double* output, const PredictionEarlyStopInstance* early_stop) const {
        PredictRawByMap(features, output, early_stop);
        if (average_output_) {
                for (int k = 0; k < num_tree_per_iteration_; ++k) {
                        output[k] /= num_iteration_for_pred_;
                }
        }
        if (objective_function_ != nullptr) {
                objective_function_->ConvertOutput(output, output);
        }
}

double PredictTree0Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0; } else { fval = arr[2];if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 1; } else { return 2; } } }
double PredictTree0LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.43908158089896321) {return 0; } else { fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.50555253269984324) {return 1; } else { return 2; } } }

double PredictTree1Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[2];if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 0; } else { return 2; } } else { return 1; } }
double PredictTree1LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(2) > 0 ? arr.at(2) : 0.0f;if (!std::isnan(fval) && fval <= 0.47134327177395619) {fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 0; } else { return 2; } } else { return 1; } }

double PredictTree2Leaf(const double* arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr[1];if (std::isnan(fval) || fval <= 0.46947632301562686) {return 0; } else { fval = arr[0];if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 1; } else { return 2; } } }
double PredictTree2LeafByMap(const std::unordered_map<int, double>& arr) { const std::vector<uint32_t> cat_threshold = {};double fval = 0.0f; fval = arr.count(1) > 0 ? arr.at(1) : 0.0f;if (std::isnan(fval) || fval <= 0.46947632301562686) {return 0; } else { fval = arr.count(0) > 0 ? arr.at(0) : 0.0f;if (std::isnan(fval) || fval <= 1.0000000180025095e-35) {return 1; } else { return 2; } } }

double (*PredictTreeLeafPtr[])(const double*) = { PredictTree0Leaf , PredictTree1Leaf , PredictTree2Leaf };

void GBDT::PredictLeafIndex(const double* features, double *output) const {
        int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
        for (int i = 0; i < total_tree; ++i) {
                output[i] = (*PredictTreeLeafPtr[i])(features);
        }
}
double (*PredictTreeLeafByMapPtr[])(const std::unordered_map<int, double>&) = { PredictTree0LeafByMap , PredictTree1LeafByMap , PredictTree2LeafByMap };

void GBDT::PredictLeafIndexByMap(const std::unordered_map<int, double>& features, double* output) const {
        int total_tree = num_iteration_for_pred_ * num_tree_per_iteration_;
        for (int i = 0; i < total_tree; ++i) {
                output[i] = (*PredictTreeLeafByMapPtr[i])(features);
        }
}
}  // namespace LightGBM

Tree0:

double PredictTree0(const double* arr) {
  const std::vector<uint32_t> cat_threshold = {};
  double fval = 0.0f;
  fval = arr[1];
  if (std::isnan(fval) || fval <= 0.43908158089896321) {  // add check for threshold
    return 0.3177485457277795;
  } else {
    fval = arr[2];
    if (!std::isnan(fval) && fval <= 0.50555253269984324) {  // add check for threshold
      return 0.44806984060967808;
    } else {
      return 0.36627160319740354;
    }
  }
}
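As a cross-check, the same result can be reached by walking the arrays stored for Tree=0 in the model file instead of reading the generated if-else code. The sketch below is not LightGBM code: it assumes that a negative child index c refers to leaf ~c, that bit 1 of decision_type is the default-left flag, and that bits 2-3 equal to 2 mean missing_type NaN. It handles only the NaN missing type used by this model, and the names goes_left and predict_tree are hypothetical.

```python
import math

# Arrays copied from Tree=0 in lightgbm_model.txt.
TREE0 = {
    "split_feature": [1, 2],
    "threshold": [0.43908158089896321, 0.50555253269984324],
    "decision_type": [10, 8],
    "left_child": [-1, -2],
    "right_child": [1, -3],
    "leaf_value": [0.3177485457277795, 0.44806984060967808, 0.36627160319740354],
}


def goes_left(fval, threshold, decision_type):
    """Numerical split rule for nodes whose missing type is NaN (assumed encoding)."""
    if (decision_type >> 2) & 3 != 2:
        raise NotImplementedError("sketch covers missing_type NaN only")
    if math.isnan(fval):
        return bool(decision_type & 2)  # missing value: follow the default direction
    return fval <= threshold  # valid value: compare with the threshold


def predict_tree(tree, row):
    """Walk from the root until a negative (leaf) index is reached."""
    node = 0
    while node >= 0:
        fval = row[tree["split_feature"][node]]
        left = goes_left(fval, tree["threshold"][node], tree["decision_type"][node])
        node = tree["left_child"][node] if left else tree["right_child"][node]
    return tree["leaf_value"][~node]


if __name__ == "__main__":
    print(predict_tree(TREE0, [0.0, 0.9, 0.3]))
```

With these assumptions the traversal picks the same leaves as the fixed PredictTree0 above, for both missing and non-missing inputs.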

@wuciting wuciting requested a review from jameslamb February 8, 2025 03:13