diff --git a/book/ate/propensity_score_and_dml.ipynb b/book/ate/propensity_score_and_dml.ipynb index c2e7639..4aa819e 100644 --- a/book/ate/propensity_score_and_dml.ipynb +++ b/book/ate/propensity_score_and_dml.ipynb @@ -11,9 +11,2650 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Matching\n", - "- IPW, AIPW, Doubly Robust Estimator\n", - "- Double Machine Learning (비모수 버전의 Regression 처럼 활용 가능)" + "- IPW, AIPW, Doubly Robust Estimator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### **Propensity Score 추정**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "출처: https://matheusfacure.github.io/python-causality-handbook/11-Propensity-Score.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "IPW와 AIPW, Doubly Robust 모두 Propensity Score를 활용한 개념들이기 때문에 먼저 Propensity Score부터 간단하게 구해보겠습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "from causalinference import CausalModel" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
| \n", + " | schoolid | \n", + "intervention | \n", + "achievement_score | \n", + "success_expect | \n", + "ethnicity | \n", + "gender | \n", + "frst_in_family | \n", + "school_urbanicity | \n", + "school_mindset | \n", + "school_achievement | \n", + "school_ethnic_minority | \n", + "school_poverty | \n", + "school_size | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 259 | \n", + "73 | \n", + "1 | \n", + "1.480828 | \n", + "5 | \n", + "1 | \n", + "2 | \n", + "0 | \n", + "1 | \n", + "-0.462945 | \n", + "0.652608 | \n", + "-0.515202 | \n", + "-0.169849 | \n", + "0.173954 | \n", + "
| 3435 | \n", + "76 | \n", + "0 | \n", + "-0.987277 | \n", + "5 | \n", + "13 | \n", + "1 | \n", + "1 | \n", + "4 | \n", + "0.334544 | \n", + "0.648586 | \n", + "-1.310927 | \n", + "0.224077 | \n", + "-0.426757 | \n", + "
| 9963 | \n", + "4 | \n", + "0 | \n", + "-0.152340 | \n", + "5 | \n", + "2 | \n", + "2 | \n", + "1 | \n", + "0 | \n", + "-2.289636 | \n", + "0.190797 | \n", + "0.875012 | \n", + "-0.724801 | \n", + "0.761781 | \n", + "
| 4488 | \n", + "67 | \n", + "0 | \n", + "0.358336 | \n", + "6 | \n", + "14 | \n", + "1 | \n", + "0 | \n", + "4 | \n", + "-1.115337 | \n", + "1.053089 | \n", + "0.315755 | \n", + "0.054586 | \n", + "1.862187 | \n", + "
| 2637 | \n", + "16 | \n", + "1 | \n", + "1.360920 | \n", + "6 | \n", + "4 | \n", + "1 | \n", + "0 | \n", + "1 | \n", + "-0.538975 | \n", + "1.433826 | \n", + "-0.033161 | \n", + "-0.982274 | \n", + "1.591641 | \n", + "
AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=False, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=False, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge())Ridge()
Ridge()
IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000))LogisticRegression(max_iter=1000)
LogisticRegression(max_iter=1000)
AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=True, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. AIPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, outcome_covariates=None, outcome_model=Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge()), overlap_weighting=True, predict_proba=False, weight_covariates=None,\n",
+ " weight_model=IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000)))Standardization(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, encode_treatment=False, predict_proba=False,\n",
+ " learner=Ridge())Ridge()
Ridge()
IPW(_doc_link_module=sklearn, _doc_link_template=https://scikit-learn.org/1.5/modules/generated/{estimator_module}.{estimator_name}.html, _doc_link_url_param_generator=None, clip_max=0.99, clip_min=0.01, use_stabilized=True, verbose=False,\n",
+ " learner=LogisticRegression(max_iter=1000))LogisticRegression(max_iter=1000)
LogisticRegression(max_iter=1000)
| \n", + " | temp | \n", + "weekday | \n", + "cost | \n", + "price | \n", + "sales | \n", + "
|---|---|---|---|---|---|
| 0 | \n", + "17.3 | \n", + "6 | \n", + "1.5 | \n", + "5.6 | \n", + "173 | \n", + "
| 1 | \n", + "25.4 | \n", + "3 | \n", + "0.3 | \n", + "4.9 | \n", + "196 | \n", + "
| 2 | \n", + "23.3 | \n", + "5 | \n", + "1.5 | \n", + "7.6 | \n", + "207 | \n", + "
| 3 | \n", + "26.9 | \n", + "1 | \n", + "0.3 | \n", + "5.3 | \n", + "241 | \n", + "
| 4 | \n", + "20.2 | \n", + "1 | \n", + "1.0 | \n", + "7.2 | \n", + "227 | \n", + "
| coef | std err | t | P>|t| | [0.025 | 0.975] | \n", + "|
|---|---|---|---|---|---|---|
| Intercept | 0.0106 | 0.072 | 0.148 | 0.883 | -0.131 | 0.152 | \n", + "
| price_res | -3.9228 | 0.071 | -54.962 | 0.000 | -4.063 | -3.783 | \n", + "
| coef | std err | t | P>|t| | [0.025 | 0.975] | \n", + "|
|---|---|---|---|---|---|---|
| Intercept | 192.9679 | 1.013 | 190.414 | 0.000 | 190.981 | 194.954 | \n", + "
| price | 1.2294 | 0.162 | 7.575 | 0.000 | 0.911 | 1.547 | \n", + "
| \n", + " | nifa | \n", + "net_tfa | \n", + "tw | \n", + "age | \n", + "inc | \n", + "fsize | \n", + "educ | \n", + "db | \n", + "marr | \n", + "twoearn | \n", + "e401 | \n", + "p401 | \n", + "pira | \n", + "hown | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "0.0 | \n", + "0.0 | \n", + "4500.0 | \n", + "47 | \n", + "6765.0 | \n", + "2 | \n", + "8 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| 1 | \n", + "6215.0 | \n", + "1015.0 | \n", + "22390.0 | \n", + "36 | \n", + "28452.0 | \n", + "1 | \n", + "16 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| 2 | \n", + "0.0 | \n", + "-2000.0 | \n", + "-2000.0 | \n", + "37 | \n", + "3300.0 | \n", + "6 | \n", + "12 | \n", + "1 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "
| 3 | \n", + "15000.0 | \n", + "15000.0 | \n", + "155000.0 | \n", + "58 | \n", + "52590.0 | \n", + "2 | \n", + "16 | \n", + "0 | \n", + "1 | \n", + "1 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| 4 | \n", + "0.0 | \n", + "0.0 | \n", + "58000.0 | \n", + "32 | \n", + "21804.0 | \n", + "1 | \n", + "11 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "0 | \n", + "1 | \n", + "
| \n", + " | coef | \n", + "std err | \n", + "t | \n", + "P>|t| | \n", + "2.5 % | \n", + "97.5 % | \n", + "
|---|---|---|---|---|---|---|
| e401 | \n", + "6229.135188 | \n", + "1461.694009 | \n", + "4.261586 | \n", + "0.00002 | \n", + "3364.267574 | \n", + "9094.002802 | \n", + "