diff --git a/book/ate/propensity_score_and_dml.ipynb b/book/ate/propensity_score_and_dml.ipynb index c2e7639..417b71e 100644 --- a/book/ate/propensity_score_and_dml.ipynb +++ b/book/ate/propensity_score_and_dml.ipynb @@ -11,9 +11,1149 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- Matching\n", - "- IPW, AIPW, Doubly Robust Estimator\n", - "- Double Machine Learning (비모수 버전의 Regression 처럼 활용 가능)" + "## Propensity Score Matching - Binary Treatment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "미국 공립 고등학교에서 수행된 성장 마인드셋(Growth Mindset) 무작위 연구를 기반으로 생성된 시뮬레이션 데이터입니다.\n", + "\n", + "- **처치(intervention)**: 성장 마인드 교육 세미나 참여 여부\n", + "- **결과(achievement_score)**: 학업 성취도 점수\n", + "\n", + "그 외 공변량들은 학생과 학교의 배경 특성을 반영합니다.\n", + "\n", + "**참고**: [Athey & Wager (2019)](https://arxiv.org/pdf/1902.07409)" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
schoolidinterventionachievement_scoresuccess_expectethnicitygenderfrst_in_familyschool_urbanicityschool_mindsetschool_achievementschool_ethnic_minorityschool_povertyschool_size
07610.277359642140.3345440.648586-1.3109270.224077-0.426757
1761-0.4496464122140.3345440.648586-1.3109270.224077-0.426757
27610.769703642040.3345440.648586-1.3109270.224077-0.426757
3761-0.121763642040.3345440.648586-1.3109270.224077-0.426757
47611.526147641040.3345440.648586-1.3109270.224077-0.426757
\n", + "
" + ], + "text/plain": [ + " schoolid intervention achievement_score success_expect ethnicity \\\n", + "0 76 1 0.277359 6 4 \n", + "1 76 1 -0.449646 4 12 \n", + "2 76 1 0.769703 6 4 \n", + "3 76 1 -0.121763 6 4 \n", + "4 76 1 1.526147 6 4 \n", + "\n", + " gender frst_in_family school_urbanicity school_mindset \\\n", + "0 2 1 4 0.334544 \n", + "1 2 1 4 0.334544 \n", + "2 2 0 4 0.334544 \n", + "3 2 0 4 0.334544 \n", + "4 1 0 4 0.334544 \n", + "\n", + " school_achievement school_ethnic_minority school_poverty school_size \n", + "0 0.648586 -1.310927 0.224077 -0.426757 \n", + "1 0.648586 -1.310927 0.224077 -0.426757 \n", + "2 0.648586 -1.310927 0.224077 -0.426757 \n", + "3 0.648586 -1.310927 0.224077 -0.426757 \n", + "4 0.648586 -1.310927 0.224077 -0.426757 " + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "data = pd.read_csv(\"../data/matheus_data/learning_mindset.csv\")\n", + "\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Propensity Score Estimation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install causalml" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.special import logit\n", + "import numpy as np\n", + "from causalml.propensity import ElasticNetPropensityModel\n", + "\n", + "categ = [\"ethnicity\",\"gender\",\"school_urbanicity\"]\n", + "cont = [\"school_mindset\",\"school_achievement\",\"school_ethnic_minority\",\"school_poverty\",\"school_size\"]\n", + "X = pd.get_dummies(data[categ + cont], columns=categ, drop_first=True)\n", + "\n", + "pm = ElasticNetPropensityModel(\n", + " random_state=42,\n", + " max_iter=5000\n", + ")\n", + "\n", + "ps = pm.fit_predict(X.values, data[\"intervention\"].values)\n", + "logit_ps = logit(np.clip(ps, 1e-6, 1-1e-6))\n", + "zlogit_ps = (logit_ps - logit_ps.mean()) / logit_ps.std(ddof=1)\n", + "\n", + "df = data.copy()\n", + "df[\"ps\"] = ps\n", + "df[\"ps_logit_z\"] = zlogit_ps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- **logit(PS) 표준화**: caliper 기준으로 사용 (Rosenbaum & Rubin, 1985)\n", + "- **clip(ε=1e-6)**: PS가 0 또는 1 근처일 때 수치 불안정 방지" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Matching - ATT" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ATT: 0.4418 | match rate: 0.948\n" + ] + } + ], + "source": [ + "from causalml.match import NearestNeighborMatch\n", + "\n", + "df_match = pd.concat([df[[\"intervention\", 'achievement_score', \"ps_logit_z\"]], X], axis=1)\n", + "\n", + "matcher_att = NearestNeighborMatch(\n", + " caliper=0.2,\n", + " replace=False,\n", + " ratio=2,\n", + " shuffle=False,\n", + " random_state=42,\n", + " treatment_to_control=True\n", + ")\n", + "\n", + "matched_att = matcher_att.match(\n", + " data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n", + ")\n", + "\n", + "n_t = (df_match[\"intervention\"]==1).sum()\n", + "n_t_matched = (matched_att[\"intervention\"]==1).sum()\n", + "match_rate = n_t_matched / n_t if n_t else float(\"nan\")\n", + "\n", + "ATT = (\n", + " matched_att.loc[matched_att[\"intervention\"]==1, 'achievement_score'].mean() \n", + " - matched_att.loc[matched_att[\"intervention\"]==0, 'achievement_score'].mean()\n", + ")\n", + "\n", + "print(f\"ATT: {ATT:.4f} | match rate: {match_rate:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- **caliper**: logit(PS) 표준편차의 0.2배를 기본값으로 사용 (Rosenbaum & Rubin, 1985)\n", + " \n", + " 공통지지가 부족하거나 매칭률이 낮을 경우 0.25, 0.30 등으로 완화 가능\n", + "- **replace**: 기본은 비복원(False). 매칭률 확보를 위해 필요 시 복원(True) 허용\n", + "- **ratio**: 1:1에서 시작해 필요 시 1:K로 확장. 조정 기준은 `|SMD|<0.1`과 매칭률\n", + "\n", + "> 세 가지 조정은 분산 감소에 도움되지만, 멀리 있거나 중복된 대조군을 포함해 **bias–variance trade-off**를 동반할 수 있음을 고려해야 합니다.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Balance Check - ATT" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": {}, + "outputs": [], + "source": [ + "def _mark_smd(val, thresh=0.1):\n", + " try:\n", + " f = float(val)\n", + " return f\"{f:.4f}\" + (\"*\" if abs(f) >= thresh else \"\")\n", + " except Exception:\n", + " return val\n", + "\n", + "def combine_table_one(table_pre, table_post, smd_star_thresh=0.1):\n", + " table_post = table_post.reindex(table_pre.index)\n", + "\n", + " out = pd.DataFrame({\n", + " \"Control (pre)\": table_pre[\"Control\"],\n", + " \"Treatment (pre)\": table_pre[\"Treatment\"],\n", + " \"SMD (pre)\": table_pre[\"SMD\"].astype(str),\n", + " \"Control (post)\": table_post[\"Control\"],\n", + " \"Treatment (post)\":table_post[\"Treatment\"],\n", + " \"SMD (post)\": table_post[\"SMD\"].astype(str),\n", + " }, index=table_pre.index)\n", + "\n", + " mask_vars = out.index != \"n\"\n", + " out.loc[mask_vars, \"SMD (pre)\"] = out.loc[mask_vars, \"SMD (pre)\"].apply(_mark_smd, thresh=smd_star_thresh)\n", + " out.loc[mask_vars, \"SMD (post)\"] = out.loc[mask_vars, \"SMD (post)\"].apply(_mark_smd, thresh=smd_star_thresh)\n", + " out.index.name = \"Variable\"\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Control (pre)Treatment (pre)SMD (pre)Control (post)Treatment (post)SMD (post)
Variable
n7007338464143209
ethnicity_100.02 (0.13)0.02 (0.13)0.01240.02 (0.13)0.02 (0.13)0.0094
ethnicity_110.02 (0.13)0.02 (0.12)-0.01170.02 (0.13)0.02 (0.13)-0.0050
ethnicity_120.03 (0.18)0.03 (0.17)-0.01230.03 (0.18)0.03 (0.17)-0.0109
ethnicity_130.02 (0.14)0.01 (0.11)-0.04920.01 (0.11)0.01 (0.12)0.0192
ethnicity_140.06 (0.24)0.06 (0.24)-0.01700.06 (0.24)0.06 (0.24)-0.0093
ethnicity_150.03 (0.18)0.03 (0.18)0.00810.04 (0.19)0.04 (0.19)-0.0001
ethnicity_20.15 (0.36)0.15 (0.36)-0.00110.15 (0.35)0.15 (0.35)0.0020
ethnicity_30.01 (0.11)0.01 (0.10)-0.01630.01 (0.09)0.01 (0.10)0.0127
ethnicity_40.48 (0.50)0.50 (0.50)0.03670.51 (0.50)0.50 (0.50)-0.0137
ethnicity_50.04 (0.19)0.04 (0.20)0.01380.04 (0.20)0.04 (0.20)-0.0056
ethnicity_60.00 (0.06)0.01 (0.07)0.03360.00 (0.06)0.00 (0.06)0.0102
ethnicity_70.00 (0.06)0.01 (0.07)0.01990.00 (0.06)0.00 (0.06)-0.0102
ethnicity_80.02 (0.13)0.02 (0.14)0.02380.02 (0.14)0.02 (0.14)-0.0001
ethnicity_90.01 (0.12)0.01 (0.10)-0.03160.01 (0.10)0.01 (0.11)0.0014
gender_20.50 (0.50)0.47 (0.50)-0.05430.48 (0.50)0.48 (0.50)0.0044
school_achievement0.04 (0.94)0.09 (0.93)0.05970.08 (0.92)0.04 (0.89)-0.0431
school_ethnic_minority-0.09 (0.97)-0.09 (0.96)-0.0048-0.14 (0.92)-0.11 (0.96)0.0314
school_mindset-0.01 (0.97)-0.10 (0.97)-0.0984-0.07 (0.95)-0.03 (0.91)0.0440
school_poverty-0.04 (0.97)-0.06 (0.96)-0.0249-0.03 (0.95)-0.03 (0.98)-0.0085
school_size-0.05 (1.00)0.02 (1.02)0.0729-0.02 (1.02)-0.01 (1.03)0.0071
school_urbanicity_10.24 (0.43)0.23 (0.42)-0.02880.24 (0.43)0.24 (0.43)-0.0135
school_urbanicity_20.19 (0.39)0.20 (0.40)0.03500.19 (0.39)0.18 (0.39)-0.0163
school_urbanicity_30.15 (0.36)0.13 (0.34)-0.04720.15 (0.36)0.14 (0.35)-0.0359
school_urbanicity_40.34 (0.48)0.36 (0.48)0.02670.34 (0.47)0.37 (0.48)0.0540
\n", + "
" + ], + "text/plain": [ + " Control (pre) Treatment (pre) SMD (pre) Control (post) \\\n", + "Variable \n", + "n 7007 3384 6414 \n", + "ethnicity_10 0.02 (0.13) 0.02 (0.13) 0.0124 0.02 (0.13) \n", + "ethnicity_11 0.02 (0.13) 0.02 (0.12) -0.0117 0.02 (0.13) \n", + "ethnicity_12 0.03 (0.18) 0.03 (0.17) -0.0123 0.03 (0.18) \n", + "ethnicity_13 0.02 (0.14) 0.01 (0.11) -0.0492 0.01 (0.11) \n", + "ethnicity_14 0.06 (0.24) 0.06 (0.24) -0.0170 0.06 (0.24) \n", + "ethnicity_15 0.03 (0.18) 0.03 (0.18) 0.0081 0.04 (0.19) \n", + "ethnicity_2 0.15 (0.36) 0.15 (0.36) -0.0011 0.15 (0.35) \n", + "ethnicity_3 0.01 (0.11) 0.01 (0.10) -0.0163 0.01 (0.09) \n", + "ethnicity_4 0.48 (0.50) 0.50 (0.50) 0.0367 0.51 (0.50) \n", + "ethnicity_5 0.04 (0.19) 0.04 (0.20) 0.0138 0.04 (0.20) \n", + "ethnicity_6 0.00 (0.06) 0.01 (0.07) 0.0336 0.00 (0.06) \n", + "ethnicity_7 0.00 (0.06) 0.01 (0.07) 0.0199 0.00 (0.06) \n", + "ethnicity_8 0.02 (0.13) 0.02 (0.14) 0.0238 0.02 (0.14) \n", + "ethnicity_9 0.01 (0.12) 0.01 (0.10) -0.0316 0.01 (0.10) \n", + "gender_2 0.50 (0.50) 0.47 (0.50) -0.0543 0.48 (0.50) \n", + "school_achievement 0.04 (0.94) 0.09 (0.93) 0.0597 0.08 (0.92) \n", + "school_ethnic_minority -0.09 (0.97) -0.09 (0.96) -0.0048 -0.14 (0.92) \n", + "school_mindset -0.01 (0.97) -0.10 (0.97) -0.0984 -0.07 (0.95) \n", + "school_poverty -0.04 (0.97) -0.06 (0.96) -0.0249 -0.03 (0.95) \n", + "school_size -0.05 (1.00) 0.02 (1.02) 0.0729 -0.02 (1.02) \n", + "school_urbanicity_1 0.24 (0.43) 0.23 (0.42) -0.0288 0.24 (0.43) \n", + "school_urbanicity_2 0.19 (0.39) 0.20 (0.40) 0.0350 0.19 (0.39) \n", + "school_urbanicity_3 0.15 (0.36) 0.13 (0.34) -0.0472 0.15 (0.36) \n", + "school_urbanicity_4 0.34 (0.48) 0.36 (0.48) 0.0267 0.34 (0.47) \n", + "\n", + " Treatment (post) SMD (post) \n", + "Variable \n", + "n 3209 \n", + "ethnicity_10 0.02 (0.13) 0.0094 \n", + "ethnicity_11 0.02 (0.13) -0.0050 \n", + "ethnicity_12 0.03 (0.17) -0.0109 \n", + "ethnicity_13 0.01 (0.12) 0.0192 \n", + "ethnicity_14 0.06 (0.24) -0.0093 \n", + "ethnicity_15 0.04 (0.19) -0.0001 \n", + "ethnicity_2 0.15 (0.35) 0.0020 \n", + "ethnicity_3 0.01 (0.10) 0.0127 \n", + "ethnicity_4 0.50 (0.50) -0.0137 \n", + "ethnicity_5 0.04 (0.20) -0.0056 \n", + "ethnicity_6 0.00 (0.06) 0.0102 \n", + "ethnicity_7 0.00 (0.06) -0.0102 \n", + "ethnicity_8 0.02 (0.14) -0.0001 \n", + "ethnicity_9 0.01 (0.11) 0.0014 \n", + "gender_2 0.48 (0.50) 0.0044 \n", + "school_achievement 0.04 (0.89) -0.0431 \n", + "school_ethnic_minority -0.11 (0.96) 0.0314 \n", + "school_mindset -0.03 (0.91) 0.0440 \n", + "school_poverty -0.03 (0.98) -0.0085 \n", + "school_size -0.01 (1.03) 0.0071 \n", + "school_urbanicity_1 0.24 (0.43) -0.0135 \n", + "school_urbanicity_2 0.18 (0.39) -0.0163 \n", + "school_urbanicity_3 0.14 (0.35) -0.0359 \n", + "school_urbanicity_4 0.37 (0.48) 0.0540 " + ] + }, + "execution_count": 141, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from causalml.match import create_table_one\n", + "\n", + "features_cov = list(X.columns)\n", + "\n", + "table_pre_att = create_table_one(\n", + " data=df_match,\n", + " treatment_col=\"intervention\",\n", + " features=features_cov,\n", + " with_std=True,\n", + " with_counts=True\n", + ")\n", + "\n", + "table_post_att = create_table_one(\n", + " data=matched_att,\n", + " treatment_col=\"intervention\",\n", + " features=features_cov,\n", + " with_std=True,\n", + " with_counts=True\n", + ")\n", + "\n", + "balance_att = combine_table_one(table_pre_att, table_post_att, smd_star_thresh=0.1)\n", + "balance_att" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Matching - ATC" + ] + }, + { + "cell_type": "code", + "execution_count": 142, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ATC: 0.3843 | control match rate: 0.999\n" + ] + } + ], + "source": [ + "from causalml.match import NearestNeighborMatch\n", + "\n", + "matcher_atc = NearestNeighborMatch(\n", + " caliper=0.3,\n", + " replace=True,\n", + " ratio=1,\n", + " shuffle=False,\n", + " random_state=42,\n", + " treatment_to_control=False\n", + ")\n", + "\n", + "matched_atc = matcher_atc.match(\n", + " data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n", + ")\n", + "\n", + "n_c = (df_match[\"intervention\"]==0).sum()\n", + "n_c_matched = (matched_atc[\"intervention\"]==0).sum()\n", + "match_rate = n_c_matched / n_c if n_c else float(\"nan\")\n", + "\n", + "ATC = (\n", + " matched_atc.loc[matched_atc[\"intervention\"]==1, \"achievement_score\"].mean()\n", + " - matched_atc.loc[matched_atc[\"intervention\"]==0, \"achievement_score\"].mean() \n", + ")\n", + "\n", + "print(f\"ATC: {ATC:.4f} | control match rate: {match_rate:.3f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Balance Check - ATC" + ] + }, + { + "cell_type": "code", + "execution_count": 143, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Control (pre)Treatment (pre)SMD (pre)Control (post)Treatment (post)SMD (post)
Variable
n7007338469996999
ethnicity_100.02 (0.13)0.02 (0.13)0.01240.02 (0.13)0.08 (0.28)0.3132*
ethnicity_110.02 (0.13)0.02 (0.12)-0.01170.02 (0.13)0.00 (0.00)-0.1852*
ethnicity_120.03 (0.18)0.03 (0.17)-0.01230.03 (0.18)0.00 (0.00)-0.2571*
ethnicity_130.02 (0.14)0.01 (0.11)-0.04920.02 (0.13)0.00 (0.06)-0.1465*
ethnicity_140.06 (0.24)0.06 (0.24)-0.01700.06 (0.25)0.00 (0.00)-0.3702*
ethnicity_150.03 (0.18)0.03 (0.18)0.00810.03 (0.18)0.07 (0.26)0.1748*
ethnicity_20.15 (0.36)0.15 (0.36)-0.00110.15 (0.36)0.18 (0.39)0.0875
ethnicity_30.01 (0.11)0.01 (0.10)-0.01630.01 (0.11)0.00 (0.00)-0.1521*
ethnicity_40.48 (0.50)0.50 (0.50)0.03670.48 (0.50)0.52 (0.50)0.0852
ethnicity_50.04 (0.19)0.04 (0.20)0.01380.04 (0.19)0.05 (0.21)0.0352
ethnicity_60.00 (0.06)0.01 (0.07)0.03360.00 (0.06)0.00 (0.06)0.0098
ethnicity_70.00 (0.06)0.01 (0.07)0.01990.00 (0.06)0.00 (0.00)-0.0863
ethnicity_80.02 (0.13)0.02 (0.14)0.02380.02 (0.13)0.02 (0.15)0.0458
ethnicity_90.01 (0.12)0.01 (0.10)-0.03160.01 (0.12)0.06 (0.24)0.2440*
gender_20.50 (0.50)0.47 (0.50)-0.05430.50 (0.50)0.53 (0.50)0.0646
school_achievement0.04 (0.94)0.09 (0.93)0.05970.04 (0.94)0.05 (1.00)0.0132
school_ethnic_minority-0.09 (0.97)-0.09 (0.96)-0.0048-0.09 (0.97)-0.14 (1.30)-0.0445
school_mindset-0.01 (0.97)-0.10 (0.97)-0.0984-0.01 (0.97)0.26 (0.65)0.3264*
school_poverty-0.04 (0.97)-0.06 (0.96)-0.0249-0.04 (0.97)0.12 (0.57)0.1968*
school_size-0.05 (1.00)0.02 (1.02)0.0729-0.05 (1.00)-0.19 (1.06)-0.1365*
school_urbanicity_10.24 (0.43)0.23 (0.42)-0.02880.24 (0.43)0.06 (0.24)-0.5145*
school_urbanicity_20.19 (0.39)0.20 (0.40)0.03500.19 (0.39)0.05 (0.21)-0.4412*
school_urbanicity_30.15 (0.36)0.13 (0.34)-0.04720.15 (0.36)0.24 (0.43)0.2238*
school_urbanicity_40.34 (0.48)0.36 (0.48)0.02670.35 (0.48)0.65 (0.48)0.6475*
\n", + "
" + ], + "text/plain": [ + " Control (pre) Treatment (pre) SMD (pre) Control (post) \\\n", + "Variable \n", + "n 7007 3384 6999 \n", + "ethnicity_10 0.02 (0.13) 0.02 (0.13) 0.0124 0.02 (0.13) \n", + "ethnicity_11 0.02 (0.13) 0.02 (0.12) -0.0117 0.02 (0.13) \n", + "ethnicity_12 0.03 (0.18) 0.03 (0.17) -0.0123 0.03 (0.18) \n", + "ethnicity_13 0.02 (0.14) 0.01 (0.11) -0.0492 0.02 (0.13) \n", + "ethnicity_14 0.06 (0.24) 0.06 (0.24) -0.0170 0.06 (0.25) \n", + "ethnicity_15 0.03 (0.18) 0.03 (0.18) 0.0081 0.03 (0.18) \n", + "ethnicity_2 0.15 (0.36) 0.15 (0.36) -0.0011 0.15 (0.36) \n", + "ethnicity_3 0.01 (0.11) 0.01 (0.10) -0.0163 0.01 (0.11) \n", + "ethnicity_4 0.48 (0.50) 0.50 (0.50) 0.0367 0.48 (0.50) \n", + "ethnicity_5 0.04 (0.19) 0.04 (0.20) 0.0138 0.04 (0.19) \n", + "ethnicity_6 0.00 (0.06) 0.01 (0.07) 0.0336 0.00 (0.06) \n", + "ethnicity_7 0.00 (0.06) 0.01 (0.07) 0.0199 0.00 (0.06) \n", + "ethnicity_8 0.02 (0.13) 0.02 (0.14) 0.0238 0.02 (0.13) \n", + "ethnicity_9 0.01 (0.12) 0.01 (0.10) -0.0316 0.01 (0.12) \n", + "gender_2 0.50 (0.50) 0.47 (0.50) -0.0543 0.50 (0.50) \n", + "school_achievement 0.04 (0.94) 0.09 (0.93) 0.0597 0.04 (0.94) \n", + "school_ethnic_minority -0.09 (0.97) -0.09 (0.96) -0.0048 -0.09 (0.97) \n", + "school_mindset -0.01 (0.97) -0.10 (0.97) -0.0984 -0.01 (0.97) \n", + "school_poverty -0.04 (0.97) -0.06 (0.96) -0.0249 -0.04 (0.97) \n", + "school_size -0.05 (1.00) 0.02 (1.02) 0.0729 -0.05 (1.00) \n", + "school_urbanicity_1 0.24 (0.43) 0.23 (0.42) -0.0288 0.24 (0.43) \n", + "school_urbanicity_2 0.19 (0.39) 0.20 (0.40) 0.0350 0.19 (0.39) \n", + "school_urbanicity_3 0.15 (0.36) 0.13 (0.34) -0.0472 0.15 (0.36) \n", + "school_urbanicity_4 0.34 (0.48) 0.36 (0.48) 0.0267 0.35 (0.48) \n", + "\n", + " Treatment (post) SMD (post) \n", + "Variable \n", + "n 6999 \n", + "ethnicity_10 0.08 (0.28) 0.3132* \n", + "ethnicity_11 0.00 (0.00) -0.1852* \n", + "ethnicity_12 0.00 (0.00) -0.2571* \n", + "ethnicity_13 0.00 (0.06) -0.1465* \n", + "ethnicity_14 0.00 (0.00) -0.3702* \n", + "ethnicity_15 0.07 (0.26) 0.1748* \n", + "ethnicity_2 0.18 (0.39) 0.0875 \n", + "ethnicity_3 0.00 (0.00) -0.1521* \n", + "ethnicity_4 0.52 (0.50) 0.0852 \n", + "ethnicity_5 0.05 (0.21) 0.0352 \n", + "ethnicity_6 0.00 (0.06) 0.0098 \n", + "ethnicity_7 0.00 (0.00) -0.0863 \n", + "ethnicity_8 0.02 (0.15) 0.0458 \n", + "ethnicity_9 0.06 (0.24) 0.2440* \n", + "gender_2 0.53 (0.50) 0.0646 \n", + "school_achievement 0.05 (1.00) 0.0132 \n", + "school_ethnic_minority -0.14 (1.30) -0.0445 \n", + "school_mindset 0.26 (0.65) 0.3264* \n", + "school_poverty 0.12 (0.57) 0.1968* \n", + "school_size -0.19 (1.06) -0.1365* \n", + "school_urbanicity_1 0.06 (0.24) -0.5145* \n", + "school_urbanicity_2 0.05 (0.21) -0.4412* \n", + "school_urbanicity_3 0.24 (0.43) 0.2238* \n", + "school_urbanicity_4 0.65 (0.48) 0.6475* " + ] + }, + "execution_count": 143, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from causalml.match import create_table_one\n", + "\n", + "table_pre_atc = create_table_one(\n", + " data=df_match,\n", + " treatment_col=\"intervention\",\n", + " features=features_cov,\n", + " with_std=True,\n", + " with_counts=True\n", + ")\n", + "\n", + "table_post_atc = create_table_one(\n", + " data=matched_atc,\n", + " treatment_col=\"intervention\",\n", + " features=features_cov,\n", + " with_std=True,\n", + " with_counts=True\n", + ")\n", + "\n", + "balance_atc = combine_table_one(table_pre_atc, table_post_atc, smd_star_thresh=0.1)\n", + "balance_atc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "ATC 매칭은 caliper와 복원 여부에 따라 매칭률과 공변량 균형 사이에 trade-off가 있습니다. \n", + "- 조건이 엄격하면 공통지지 부족으로 매칭이 불가능하고, \n", + "- 조건을 완화하면 반복 사용과 먼 거리 매칭으로 SMD가 악화됩니다.\n", + "\n", + "따라서 처치군과 대조군의 표본 불균형이 큰 경우, 매칭만으로는 ATE 추정에 제약이 있습니다. 모집단 평균 효과를 안정적으로 추정하려면 IPW, AIPW 등 가중치 기반 방법이 더 적합합니다." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ATE" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ATE: 0.403\n" + ] + } + ], + "source": [ + "p_t = float(df[\"intervention\"].mean())\n", + "ATE = ATT * p_t + ATC * (1 - p_t)\n", + "print(\"ATE:\", round(ATE, 4))" ] }, { @@ -33,7 +1173,7 @@ ], "metadata": { "kernelspec": { - "display_name": "base", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -47,7 +1187,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.11" } }, "nbformat": 4,