Skip to content

Added Random Forest Regressor as an additional prediction model. #12767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
44 changes: 44 additions & 0 deletions machine_learning/forecasting/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

from warnings import simplefilter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import Normalizer
from sklearn.svm import SVR
from statsmodels.tsa.statespace.sarimax import SARIMAX
Expand Down Expand Up @@ -78,6 +80,29 @@ def support_vector_regressor(x_train: list, x_test: list, train_user: list) -> f
return float(y_pred[0])


def random_forest_regressor(x_train: list, x_test: list, train_user: list) -> float:
"""
Fourth method: Random Forest Regressor
Random Forest is an ensemble learning method for regression that operates
by constructing a multitude of decision trees at training time and outputting
the mean prediction of the individual trees.

It is more robust than a single decision tree and less prone to overfitting.
Good for capturing nonlinear relationships in data.

input : training data (date, total_event) in list of float
where x = list of set (date and total event)
output : list of total user prediction in float

>>> random_forest_regressor([[5,2],[1,5],[6,2]], [[3,2]], [2,1,4])
1.95
"""
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(x_train, train_user)
prediction = model.predict(x_test)
return float(prediction[0])


def interquartile_range_checker(train_user: list) -> float:
"""
Optional method: interquatile range
Expand Down Expand Up @@ -120,6 +145,22 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool:
return safe > not_safe


def plot_forecast(actual, predictions):
plt.figure(figsize=(10, 5))
plt.plot(range(len(actual)), actual, label="Actual")
plt.plot(len(actual), predictions[0], "ro", label="Linear Reg")
plt.plot(len(actual), predictions[1], "go", label="SARIMAX")
plt.plot(len(actual), predictions[2], "bo", label="SVR")
plt.plot(len(actual), predictions[3], "yo", label="RF")
plt.legend()
plt.title("Data Safety Forecast")
plt.xlabel("Days")
plt.ylabel("Normalized User Count")
plt.grid(True)
plt.tight_layout()
plt.show()


if __name__ == "__main__":
"""
data column = total user in a day, how much online event held in one day,
Expand Down Expand Up @@ -155,8 +196,11 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool:
),
sarimax_predictor(train_user, train_match, test_match),
support_vector_regressor(x_train, x_test, train_user),
random_forest_regressor(x_train, x_test, train_user),
]

# check the safety of today's data
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
print(f"Today's data is {not_str}safe.")

plot_forecast(train_user, res_vote)