diff --git a/book/ate/propensity_score_and_dml.ipynb b/book/ate/propensity_score_and_dml.ipynb
index c2e7639..417b71e 100644
--- a/book/ate/propensity_score_and_dml.ipynb
+++ b/book/ate/propensity_score_and_dml.ipynb
@@ -11,9 +11,1149 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "- Matching\n",
- "- IPW, AIPW, Doubly Robust Estimator\n",
- "- Double Machine Learning (비모수 버전의 Regression 처럼 활용 가능)"
+ "## Propensity Score Matching - Binary Treatment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "미국 공립 고등학교에서 수행된 성장 마인드셋(Growth Mindset) 무작위 연구를 기반으로 생성된 시뮬레이션 데이터입니다.\n",
+ "\n",
+ "- **처치(intervention)**: 성장 마인드 교육 세미나 참여 여부\n",
+ "- **결과(achievement_score)**: 학업 성취도 점수\n",
+ "\n",
+ "그 외 공변량들은 학생과 학교의 배경 특성을 반영합니다.\n",
+ "\n",
+ "**참고**: [Athey & Wager (2019)](https://arxiv.org/pdf/1902.07409)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " schoolid | \n",
+ " intervention | \n",
+ " achievement_score | \n",
+ " success_expect | \n",
+ " ethnicity | \n",
+ " gender | \n",
+ " frst_in_family | \n",
+ " school_urbanicity | \n",
+ " school_mindset | \n",
+ " school_achievement | \n",
+ " school_ethnic_minority | \n",
+ " school_poverty | \n",
+ " school_size | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 76 | \n",
+ " 1 | \n",
+ " 0.277359 | \n",
+ " 6 | \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 4 | \n",
+ " 0.334544 | \n",
+ " 0.648586 | \n",
+ " -1.310927 | \n",
+ " 0.224077 | \n",
+ " -0.426757 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 76 | \n",
+ " 1 | \n",
+ " -0.449646 | \n",
+ " 4 | \n",
+ " 12 | \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 4 | \n",
+ " 0.334544 | \n",
+ " 0.648586 | \n",
+ " -1.310927 | \n",
+ " 0.224077 | \n",
+ " -0.426757 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 76 | \n",
+ " 1 | \n",
+ " 0.769703 | \n",
+ " 6 | \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 4 | \n",
+ " 0.334544 | \n",
+ " 0.648586 | \n",
+ " -1.310927 | \n",
+ " 0.224077 | \n",
+ " -0.426757 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 76 | \n",
+ " 1 | \n",
+ " -0.121763 | \n",
+ " 6 | \n",
+ " 4 | \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 4 | \n",
+ " 0.334544 | \n",
+ " 0.648586 | \n",
+ " -1.310927 | \n",
+ " 0.224077 | \n",
+ " -0.426757 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 76 | \n",
+ " 1 | \n",
+ " 1.526147 | \n",
+ " 6 | \n",
+ " 4 | \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 4 | \n",
+ " 0.334544 | \n",
+ " 0.648586 | \n",
+ " -1.310927 | \n",
+ " 0.224077 | \n",
+ " -0.426757 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " schoolid intervention achievement_score success_expect ethnicity \\\n",
+ "0 76 1 0.277359 6 4 \n",
+ "1 76 1 -0.449646 4 12 \n",
+ "2 76 1 0.769703 6 4 \n",
+ "3 76 1 -0.121763 6 4 \n",
+ "4 76 1 1.526147 6 4 \n",
+ "\n",
+ " gender frst_in_family school_urbanicity school_mindset \\\n",
+ "0 2 1 4 0.334544 \n",
+ "1 2 1 4 0.334544 \n",
+ "2 2 0 4 0.334544 \n",
+ "3 2 0 4 0.334544 \n",
+ "4 1 0 4 0.334544 \n",
+ "\n",
+ " school_achievement school_ethnic_minority school_poverty school_size \n",
+ "0 0.648586 -1.310927 0.224077 -0.426757 \n",
+ "1 0.648586 -1.310927 0.224077 -0.426757 \n",
+ "2 0.648586 -1.310927 0.224077 -0.426757 \n",
+ "3 0.648586 -1.310927 0.224077 -0.426757 \n",
+ "4 0.648586 -1.310927 0.224077 -0.426757 "
+ ]
+ },
+ "execution_count": 82,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "data = pd.read_csv(\"../data/matheus_data/learning_mindset.csv\")\n",
+ "\n",
+ "data.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Propensity Score Estimation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install causalml"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 108,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from scipy.special import logit\n",
+ "import numpy as np\n",
+ "from causalml.propensity import ElasticNetPropensityModel\n",
+ "\n",
+ "categ = [\"ethnicity\",\"gender\",\"school_urbanicity\"]\n",
+ "cont = [\"school_mindset\",\"school_achievement\",\"school_ethnic_minority\",\"school_poverty\",\"school_size\"]\n",
+ "X = pd.get_dummies(data[categ + cont], columns=categ, drop_first=True)\n",
+ "\n",
+ "pm = ElasticNetPropensityModel(\n",
+ " random_state=42,\n",
+ " max_iter=5000\n",
+ ")\n",
+ "\n",
+ "ps = pm.fit_predict(X.values, data[\"intervention\"].values)\n",
+ "logit_ps = logit(np.clip(ps, 1e-6, 1-1e-6))\n",
+ "zlogit_ps = (logit_ps - logit_ps.mean()) / logit_ps.std(ddof=1)\n",
+ "\n",
+ "df = data.copy()\n",
+ "df[\"ps\"] = ps\n",
+ "df[\"ps_logit_z\"] = zlogit_ps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **logit(PS) 표준화**: caliper 기준으로 사용 (Rosenbaum & Rubin, 1985)\n",
+ "- **clip(ε=1e-6)**: PS가 0 또는 1 근처일 때 수치 불안정 방지"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Matching - ATT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 139,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ATT: 0.4418 | match rate: 0.948\n"
+ ]
+ }
+ ],
+ "source": [
+ "from causalml.match import NearestNeighborMatch\n",
+ "\n",
+ "df_match = pd.concat([df[[\"intervention\", 'achievement_score', \"ps_logit_z\"]], X], axis=1)\n",
+ "\n",
+ "matcher_att = NearestNeighborMatch(\n",
+ " caliper=0.2,\n",
+ " replace=False,\n",
+ " ratio=2,\n",
+ " shuffle=False,\n",
+ " random_state=42,\n",
+ " treatment_to_control=True\n",
+ ")\n",
+ "\n",
+ "matched_att = matcher_att.match(\n",
+ " data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n",
+ ")\n",
+ "\n",
+ "n_t = (df_match[\"intervention\"]==1).sum()\n",
+ "n_t_matched = (matched_att[\"intervention\"]==1).sum()\n",
+ "match_rate = n_t_matched / n_t if n_t else float(\"nan\")\n",
+ "\n",
+ "ATT = (\n",
+ " matched_att.loc[matched_att[\"intervention\"]==1, 'achievement_score'].mean() \n",
+ " - matched_att.loc[matched_att[\"intervention\"]==0, 'achievement_score'].mean()\n",
+ ")\n",
+ "\n",
+ "print(f\"ATT: {ATT:.4f} | match rate: {match_rate:.3f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **caliper**: logit(PS) 표준편차의 0.2배를 기본값으로 사용 (Rosenbaum & Rubin, 1985)\n",
+ " \n",
+ " 공통지지가 부족하거나 매칭률이 낮을 경우 0.25, 0.30 등으로 완화 가능\n",
+ "- **replace**: 기본은 비복원(False). 매칭률 확보를 위해 필요 시 복원(True) 허용\n",
+ "- **ratio**: 1:1에서 시작해 필요 시 1:K로 확장. 조정 기준은 `|SMD|<0.1`과 매칭률\n",
+ "\n",
+ "> 세 가지 조정은 분산 감소에 도움되지만, 멀리 있거나 중복된 대조군을 포함해 **bias–variance trade-off**를 동반할 수 있음을 고려해야 합니다.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Balance Check - ATT"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 140,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def _mark_smd(val, thresh=0.1):\n",
+ " try:\n",
+ " f = float(val)\n",
+ " return f\"{f:.4f}\" + (\"*\" if abs(f) >= thresh else \"\")\n",
+ " except Exception:\n",
+ " return val\n",
+ "\n",
+ "def combine_table_one(table_pre, table_post, smd_star_thresh=0.1):\n",
+ " table_post = table_post.reindex(table_pre.index)\n",
+ "\n",
+ " out = pd.DataFrame({\n",
+ " \"Control (pre)\": table_pre[\"Control\"],\n",
+ " \"Treatment (pre)\": table_pre[\"Treatment\"],\n",
+ " \"SMD (pre)\": table_pre[\"SMD\"].astype(str),\n",
+ " \"Control (post)\": table_post[\"Control\"],\n",
+ " \"Treatment (post)\":table_post[\"Treatment\"],\n",
+ " \"SMD (post)\": table_post[\"SMD\"].astype(str),\n",
+ " }, index=table_pre.index)\n",
+ "\n",
+ " mask_vars = out.index != \"n\"\n",
+ " out.loc[mask_vars, \"SMD (pre)\"] = out.loc[mask_vars, \"SMD (pre)\"].apply(_mark_smd, thresh=smd_star_thresh)\n",
+ " out.loc[mask_vars, \"SMD (post)\"] = out.loc[mask_vars, \"SMD (post)\"].apply(_mark_smd, thresh=smd_star_thresh)\n",
+ " out.index.name = \"Variable\"\n",
+ " return out"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 141,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Control (pre) | \n",
+ " Treatment (pre) | \n",
+ " SMD (pre) | \n",
+ " Control (post) | \n",
+ " Treatment (post) | \n",
+ " SMD (post) | \n",
+ "
\n",
+ " \n",
+ " | Variable | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | n | \n",
+ " 7007 | \n",
+ " 3384 | \n",
+ " | \n",
+ " 6414 | \n",
+ " 3209 | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_10 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.13) | \n",
+ " 0.0124 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.13) | \n",
+ " 0.0094 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_11 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.12) | \n",
+ " -0.0117 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.13) | \n",
+ " -0.0050 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_12 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.03 (0.17) | \n",
+ " -0.0123 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.03 (0.17) | \n",
+ " -0.0109 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_13 | \n",
+ " 0.02 (0.14) | \n",
+ " 0.01 (0.11) | \n",
+ " -0.0492 | \n",
+ " 0.01 (0.11) | \n",
+ " 0.01 (0.12) | \n",
+ " 0.0192 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_14 | \n",
+ " 0.06 (0.24) | \n",
+ " 0.06 (0.24) | \n",
+ " -0.0170 | \n",
+ " 0.06 (0.24) | \n",
+ " 0.06 (0.24) | \n",
+ " -0.0093 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_15 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.03 (0.18) | \n",
+ " 0.0081 | \n",
+ " 0.04 (0.19) | \n",
+ " 0.04 (0.19) | \n",
+ " -0.0001 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_2 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.15 (0.36) | \n",
+ " -0.0011 | \n",
+ " 0.15 (0.35) | \n",
+ " 0.15 (0.35) | \n",
+ " 0.0020 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_3 | \n",
+ " 0.01 (0.11) | \n",
+ " 0.01 (0.10) | \n",
+ " -0.0163 | \n",
+ " 0.01 (0.09) | \n",
+ " 0.01 (0.10) | \n",
+ " 0.0127 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_4 | \n",
+ " 0.48 (0.50) | \n",
+ " 0.50 (0.50) | \n",
+ " 0.0367 | \n",
+ " 0.51 (0.50) | \n",
+ " 0.50 (0.50) | \n",
+ " -0.0137 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_5 | \n",
+ " 0.04 (0.19) | \n",
+ " 0.04 (0.20) | \n",
+ " 0.0138 | \n",
+ " 0.04 (0.20) | \n",
+ " 0.04 (0.20) | \n",
+ " -0.0056 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_6 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.01 (0.07) | \n",
+ " 0.0336 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.00 (0.06) | \n",
+ " 0.0102 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_7 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.01 (0.07) | \n",
+ " 0.0199 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.00 (0.06) | \n",
+ " -0.0102 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_8 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.14) | \n",
+ " 0.0238 | \n",
+ " 0.02 (0.14) | \n",
+ " 0.02 (0.14) | \n",
+ " -0.0001 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_9 | \n",
+ " 0.01 (0.12) | \n",
+ " 0.01 (0.10) | \n",
+ " -0.0316 | \n",
+ " 0.01 (0.10) | \n",
+ " 0.01 (0.11) | \n",
+ " 0.0014 | \n",
+ "
\n",
+ " \n",
+ " | gender_2 | \n",
+ " 0.50 (0.50) | \n",
+ " 0.47 (0.50) | \n",
+ " -0.0543 | \n",
+ " 0.48 (0.50) | \n",
+ " 0.48 (0.50) | \n",
+ " 0.0044 | \n",
+ "
\n",
+ " \n",
+ " | school_achievement | \n",
+ " 0.04 (0.94) | \n",
+ " 0.09 (0.93) | \n",
+ " 0.0597 | \n",
+ " 0.08 (0.92) | \n",
+ " 0.04 (0.89) | \n",
+ " -0.0431 | \n",
+ "
\n",
+ " \n",
+ " | school_ethnic_minority | \n",
+ " -0.09 (0.97) | \n",
+ " -0.09 (0.96) | \n",
+ " -0.0048 | \n",
+ " -0.14 (0.92) | \n",
+ " -0.11 (0.96) | \n",
+ " 0.0314 | \n",
+ "
\n",
+ " \n",
+ " | school_mindset | \n",
+ " -0.01 (0.97) | \n",
+ " -0.10 (0.97) | \n",
+ " -0.0984 | \n",
+ " -0.07 (0.95) | \n",
+ " -0.03 (0.91) | \n",
+ " 0.0440 | \n",
+ "
\n",
+ " \n",
+ " | school_poverty | \n",
+ " -0.04 (0.97) | \n",
+ " -0.06 (0.96) | \n",
+ " -0.0249 | \n",
+ " -0.03 (0.95) | \n",
+ " -0.03 (0.98) | \n",
+ " -0.0085 | \n",
+ "
\n",
+ " \n",
+ " | school_size | \n",
+ " -0.05 (1.00) | \n",
+ " 0.02 (1.02) | \n",
+ " 0.0729 | \n",
+ " -0.02 (1.02) | \n",
+ " -0.01 (1.03) | \n",
+ " 0.0071 | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_1 | \n",
+ " 0.24 (0.43) | \n",
+ " 0.23 (0.42) | \n",
+ " -0.0288 | \n",
+ " 0.24 (0.43) | \n",
+ " 0.24 (0.43) | \n",
+ " -0.0135 | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_2 | \n",
+ " 0.19 (0.39) | \n",
+ " 0.20 (0.40) | \n",
+ " 0.0350 | \n",
+ " 0.19 (0.39) | \n",
+ " 0.18 (0.39) | \n",
+ " -0.0163 | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_3 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.13 (0.34) | \n",
+ " -0.0472 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.14 (0.35) | \n",
+ " -0.0359 | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_4 | \n",
+ " 0.34 (0.48) | \n",
+ " 0.36 (0.48) | \n",
+ " 0.0267 | \n",
+ " 0.34 (0.47) | \n",
+ " 0.37 (0.48) | \n",
+ " 0.0540 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Control (pre) Treatment (pre) SMD (pre) Control (post) \\\n",
+ "Variable \n",
+ "n 7007 3384 6414 \n",
+ "ethnicity_10 0.02 (0.13) 0.02 (0.13) 0.0124 0.02 (0.13) \n",
+ "ethnicity_11 0.02 (0.13) 0.02 (0.12) -0.0117 0.02 (0.13) \n",
+ "ethnicity_12 0.03 (0.18) 0.03 (0.17) -0.0123 0.03 (0.18) \n",
+ "ethnicity_13 0.02 (0.14) 0.01 (0.11) -0.0492 0.01 (0.11) \n",
+ "ethnicity_14 0.06 (0.24) 0.06 (0.24) -0.0170 0.06 (0.24) \n",
+ "ethnicity_15 0.03 (0.18) 0.03 (0.18) 0.0081 0.04 (0.19) \n",
+ "ethnicity_2 0.15 (0.36) 0.15 (0.36) -0.0011 0.15 (0.35) \n",
+ "ethnicity_3 0.01 (0.11) 0.01 (0.10) -0.0163 0.01 (0.09) \n",
+ "ethnicity_4 0.48 (0.50) 0.50 (0.50) 0.0367 0.51 (0.50) \n",
+ "ethnicity_5 0.04 (0.19) 0.04 (0.20) 0.0138 0.04 (0.20) \n",
+ "ethnicity_6 0.00 (0.06) 0.01 (0.07) 0.0336 0.00 (0.06) \n",
+ "ethnicity_7 0.00 (0.06) 0.01 (0.07) 0.0199 0.00 (0.06) \n",
+ "ethnicity_8 0.02 (0.13) 0.02 (0.14) 0.0238 0.02 (0.14) \n",
+ "ethnicity_9 0.01 (0.12) 0.01 (0.10) -0.0316 0.01 (0.10) \n",
+ "gender_2 0.50 (0.50) 0.47 (0.50) -0.0543 0.48 (0.50) \n",
+ "school_achievement 0.04 (0.94) 0.09 (0.93) 0.0597 0.08 (0.92) \n",
+ "school_ethnic_minority -0.09 (0.97) -0.09 (0.96) -0.0048 -0.14 (0.92) \n",
+ "school_mindset -0.01 (0.97) -0.10 (0.97) -0.0984 -0.07 (0.95) \n",
+ "school_poverty -0.04 (0.97) -0.06 (0.96) -0.0249 -0.03 (0.95) \n",
+ "school_size -0.05 (1.00) 0.02 (1.02) 0.0729 -0.02 (1.02) \n",
+ "school_urbanicity_1 0.24 (0.43) 0.23 (0.42) -0.0288 0.24 (0.43) \n",
+ "school_urbanicity_2 0.19 (0.39) 0.20 (0.40) 0.0350 0.19 (0.39) \n",
+ "school_urbanicity_3 0.15 (0.36) 0.13 (0.34) -0.0472 0.15 (0.36) \n",
+ "school_urbanicity_4 0.34 (0.48) 0.36 (0.48) 0.0267 0.34 (0.47) \n",
+ "\n",
+ " Treatment (post) SMD (post) \n",
+ "Variable \n",
+ "n 3209 \n",
+ "ethnicity_10 0.02 (0.13) 0.0094 \n",
+ "ethnicity_11 0.02 (0.13) -0.0050 \n",
+ "ethnicity_12 0.03 (0.17) -0.0109 \n",
+ "ethnicity_13 0.01 (0.12) 0.0192 \n",
+ "ethnicity_14 0.06 (0.24) -0.0093 \n",
+ "ethnicity_15 0.04 (0.19) -0.0001 \n",
+ "ethnicity_2 0.15 (0.35) 0.0020 \n",
+ "ethnicity_3 0.01 (0.10) 0.0127 \n",
+ "ethnicity_4 0.50 (0.50) -0.0137 \n",
+ "ethnicity_5 0.04 (0.20) -0.0056 \n",
+ "ethnicity_6 0.00 (0.06) 0.0102 \n",
+ "ethnicity_7 0.00 (0.06) -0.0102 \n",
+ "ethnicity_8 0.02 (0.14) -0.0001 \n",
+ "ethnicity_9 0.01 (0.11) 0.0014 \n",
+ "gender_2 0.48 (0.50) 0.0044 \n",
+ "school_achievement 0.04 (0.89) -0.0431 \n",
+ "school_ethnic_minority -0.11 (0.96) 0.0314 \n",
+ "school_mindset -0.03 (0.91) 0.0440 \n",
+ "school_poverty -0.03 (0.98) -0.0085 \n",
+ "school_size -0.01 (1.03) 0.0071 \n",
+ "school_urbanicity_1 0.24 (0.43) -0.0135 \n",
+ "school_urbanicity_2 0.18 (0.39) -0.0163 \n",
+ "school_urbanicity_3 0.14 (0.35) -0.0359 \n",
+ "school_urbanicity_4 0.37 (0.48) 0.0540 "
+ ]
+ },
+ "execution_count": 141,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from causalml.match import create_table_one\n",
+ "\n",
+ "features_cov = list(X.columns)\n",
+ "\n",
+ "table_pre_att = create_table_one(\n",
+ " data=df_match,\n",
+ " treatment_col=\"intervention\",\n",
+ " features=features_cov,\n",
+ " with_std=True,\n",
+ " with_counts=True\n",
+ ")\n",
+ "\n",
+ "table_post_att = create_table_one(\n",
+ " data=matched_att,\n",
+ " treatment_col=\"intervention\",\n",
+ " features=features_cov,\n",
+ " with_std=True,\n",
+ " with_counts=True\n",
+ ")\n",
+ "\n",
+ "balance_att = combine_table_one(table_pre_att, table_post_att, smd_star_thresh=0.1)\n",
+ "balance_att"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Matching - ATC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 142,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ATC: 0.3843 | control match rate: 0.999\n"
+ ]
+ }
+ ],
+ "source": [
+ "from causalml.match import NearestNeighborMatch\n",
+ "\n",
+ "matcher_atc = NearestNeighborMatch(\n",
+ " caliper=0.3,\n",
+ " replace=True,\n",
+ " ratio=1,\n",
+ " shuffle=False,\n",
+ " random_state=42,\n",
+ " treatment_to_control=False\n",
+ ")\n",
+ "\n",
+ "matched_atc = matcher_atc.match(\n",
+ " data=df_match, treatment_col=\"intervention\", score_cols=[\"ps_logit_z\"]\n",
+ ")\n",
+ "\n",
+ "n_c = (df_match[\"intervention\"]==0).sum()\n",
+ "n_c_matched = (matched_atc[\"intervention\"]==0).sum()\n",
+ "match_rate = n_c_matched / n_c if n_c else float(\"nan\")\n",
+ "\n",
+ "ATC = (\n",
+ " matched_atc.loc[matched_atc[\"intervention\"]==1, \"achievement_score\"].mean()\n",
+ " - matched_atc.loc[matched_atc[\"intervention\"]==0, \"achievement_score\"].mean() \n",
+ ")\n",
+ "\n",
+ "print(f\"ATC: {ATC:.4f} | control match rate: {match_rate:.3f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Balance Check - ATC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 143,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Control (pre) | \n",
+ " Treatment (pre) | \n",
+ " SMD (pre) | \n",
+ " Control (post) | \n",
+ " Treatment (post) | \n",
+ " SMD (post) | \n",
+ "
\n",
+ " \n",
+ " | Variable | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | n | \n",
+ " 7007 | \n",
+ " 3384 | \n",
+ " | \n",
+ " 6999 | \n",
+ " 6999 | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_10 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.13) | \n",
+ " 0.0124 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.08 (0.28) | \n",
+ " 0.3132* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_11 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.12) | \n",
+ " -0.0117 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.00 (0.00) | \n",
+ " -0.1852* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_12 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.03 (0.17) | \n",
+ " -0.0123 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.00 (0.00) | \n",
+ " -0.2571* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_13 | \n",
+ " 0.02 (0.14) | \n",
+ " 0.01 (0.11) | \n",
+ " -0.0492 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.00 (0.06) | \n",
+ " -0.1465* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_14 | \n",
+ " 0.06 (0.24) | \n",
+ " 0.06 (0.24) | \n",
+ " -0.0170 | \n",
+ " 0.06 (0.25) | \n",
+ " 0.00 (0.00) | \n",
+ " -0.3702* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_15 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.03 (0.18) | \n",
+ " 0.0081 | \n",
+ " 0.03 (0.18) | \n",
+ " 0.07 (0.26) | \n",
+ " 0.1748* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_2 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.15 (0.36) | \n",
+ " -0.0011 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.18 (0.39) | \n",
+ " 0.0875 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_3 | \n",
+ " 0.01 (0.11) | \n",
+ " 0.01 (0.10) | \n",
+ " -0.0163 | \n",
+ " 0.01 (0.11) | \n",
+ " 0.00 (0.00) | \n",
+ " -0.1521* | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_4 | \n",
+ " 0.48 (0.50) | \n",
+ " 0.50 (0.50) | \n",
+ " 0.0367 | \n",
+ " 0.48 (0.50) | \n",
+ " 0.52 (0.50) | \n",
+ " 0.0852 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_5 | \n",
+ " 0.04 (0.19) | \n",
+ " 0.04 (0.20) | \n",
+ " 0.0138 | \n",
+ " 0.04 (0.19) | \n",
+ " 0.05 (0.21) | \n",
+ " 0.0352 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_6 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.01 (0.07) | \n",
+ " 0.0336 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.00 (0.06) | \n",
+ " 0.0098 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_7 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.01 (0.07) | \n",
+ " 0.0199 | \n",
+ " 0.00 (0.06) | \n",
+ " 0.00 (0.00) | \n",
+ " -0.0863 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_8 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.14) | \n",
+ " 0.0238 | \n",
+ " 0.02 (0.13) | \n",
+ " 0.02 (0.15) | \n",
+ " 0.0458 | \n",
+ "
\n",
+ " \n",
+ " | ethnicity_9 | \n",
+ " 0.01 (0.12) | \n",
+ " 0.01 (0.10) | \n",
+ " -0.0316 | \n",
+ " 0.01 (0.12) | \n",
+ " 0.06 (0.24) | \n",
+ " 0.2440* | \n",
+ "
\n",
+ " \n",
+ " | gender_2 | \n",
+ " 0.50 (0.50) | \n",
+ " 0.47 (0.50) | \n",
+ " -0.0543 | \n",
+ " 0.50 (0.50) | \n",
+ " 0.53 (0.50) | \n",
+ " 0.0646 | \n",
+ "
\n",
+ " \n",
+ " | school_achievement | \n",
+ " 0.04 (0.94) | \n",
+ " 0.09 (0.93) | \n",
+ " 0.0597 | \n",
+ " 0.04 (0.94) | \n",
+ " 0.05 (1.00) | \n",
+ " 0.0132 | \n",
+ "
\n",
+ " \n",
+ " | school_ethnic_minority | \n",
+ " -0.09 (0.97) | \n",
+ " -0.09 (0.96) | \n",
+ " -0.0048 | \n",
+ " -0.09 (0.97) | \n",
+ " -0.14 (1.30) | \n",
+ " -0.0445 | \n",
+ "
\n",
+ " \n",
+ " | school_mindset | \n",
+ " -0.01 (0.97) | \n",
+ " -0.10 (0.97) | \n",
+ " -0.0984 | \n",
+ " -0.01 (0.97) | \n",
+ " 0.26 (0.65) | \n",
+ " 0.3264* | \n",
+ "
\n",
+ " \n",
+ " | school_poverty | \n",
+ " -0.04 (0.97) | \n",
+ " -0.06 (0.96) | \n",
+ " -0.0249 | \n",
+ " -0.04 (0.97) | \n",
+ " 0.12 (0.57) | \n",
+ " 0.1968* | \n",
+ "
\n",
+ " \n",
+ " | school_size | \n",
+ " -0.05 (1.00) | \n",
+ " 0.02 (1.02) | \n",
+ " 0.0729 | \n",
+ " -0.05 (1.00) | \n",
+ " -0.19 (1.06) | \n",
+ " -0.1365* | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_1 | \n",
+ " 0.24 (0.43) | \n",
+ " 0.23 (0.42) | \n",
+ " -0.0288 | \n",
+ " 0.24 (0.43) | \n",
+ " 0.06 (0.24) | \n",
+ " -0.5145* | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_2 | \n",
+ " 0.19 (0.39) | \n",
+ " 0.20 (0.40) | \n",
+ " 0.0350 | \n",
+ " 0.19 (0.39) | \n",
+ " 0.05 (0.21) | \n",
+ " -0.4412* | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_3 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.13 (0.34) | \n",
+ " -0.0472 | \n",
+ " 0.15 (0.36) | \n",
+ " 0.24 (0.43) | \n",
+ " 0.2238* | \n",
+ "
\n",
+ " \n",
+ " | school_urbanicity_4 | \n",
+ " 0.34 (0.48) | \n",
+ " 0.36 (0.48) | \n",
+ " 0.0267 | \n",
+ " 0.35 (0.48) | \n",
+ " 0.65 (0.48) | \n",
+ " 0.6475* | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Control (pre) Treatment (pre) SMD (pre) Control (post) \\\n",
+ "Variable \n",
+ "n 7007 3384 6999 \n",
+ "ethnicity_10 0.02 (0.13) 0.02 (0.13) 0.0124 0.02 (0.13) \n",
+ "ethnicity_11 0.02 (0.13) 0.02 (0.12) -0.0117 0.02 (0.13) \n",
+ "ethnicity_12 0.03 (0.18) 0.03 (0.17) -0.0123 0.03 (0.18) \n",
+ "ethnicity_13 0.02 (0.14) 0.01 (0.11) -0.0492 0.02 (0.13) \n",
+ "ethnicity_14 0.06 (0.24) 0.06 (0.24) -0.0170 0.06 (0.25) \n",
+ "ethnicity_15 0.03 (0.18) 0.03 (0.18) 0.0081 0.03 (0.18) \n",
+ "ethnicity_2 0.15 (0.36) 0.15 (0.36) -0.0011 0.15 (0.36) \n",
+ "ethnicity_3 0.01 (0.11) 0.01 (0.10) -0.0163 0.01 (0.11) \n",
+ "ethnicity_4 0.48 (0.50) 0.50 (0.50) 0.0367 0.48 (0.50) \n",
+ "ethnicity_5 0.04 (0.19) 0.04 (0.20) 0.0138 0.04 (0.19) \n",
+ "ethnicity_6 0.00 (0.06) 0.01 (0.07) 0.0336 0.00 (0.06) \n",
+ "ethnicity_7 0.00 (0.06) 0.01 (0.07) 0.0199 0.00 (0.06) \n",
+ "ethnicity_8 0.02 (0.13) 0.02 (0.14) 0.0238 0.02 (0.13) \n",
+ "ethnicity_9 0.01 (0.12) 0.01 (0.10) -0.0316 0.01 (0.12) \n",
+ "gender_2 0.50 (0.50) 0.47 (0.50) -0.0543 0.50 (0.50) \n",
+ "school_achievement 0.04 (0.94) 0.09 (0.93) 0.0597 0.04 (0.94) \n",
+ "school_ethnic_minority -0.09 (0.97) -0.09 (0.96) -0.0048 -0.09 (0.97) \n",
+ "school_mindset -0.01 (0.97) -0.10 (0.97) -0.0984 -0.01 (0.97) \n",
+ "school_poverty -0.04 (0.97) -0.06 (0.96) -0.0249 -0.04 (0.97) \n",
+ "school_size -0.05 (1.00) 0.02 (1.02) 0.0729 -0.05 (1.00) \n",
+ "school_urbanicity_1 0.24 (0.43) 0.23 (0.42) -0.0288 0.24 (0.43) \n",
+ "school_urbanicity_2 0.19 (0.39) 0.20 (0.40) 0.0350 0.19 (0.39) \n",
+ "school_urbanicity_3 0.15 (0.36) 0.13 (0.34) -0.0472 0.15 (0.36) \n",
+ "school_urbanicity_4 0.34 (0.48) 0.36 (0.48) 0.0267 0.35 (0.48) \n",
+ "\n",
+ " Treatment (post) SMD (post) \n",
+ "Variable \n",
+ "n 6999 \n",
+ "ethnicity_10 0.08 (0.28) 0.3132* \n",
+ "ethnicity_11 0.00 (0.00) -0.1852* \n",
+ "ethnicity_12 0.00 (0.00) -0.2571* \n",
+ "ethnicity_13 0.00 (0.06) -0.1465* \n",
+ "ethnicity_14 0.00 (0.00) -0.3702* \n",
+ "ethnicity_15 0.07 (0.26) 0.1748* \n",
+ "ethnicity_2 0.18 (0.39) 0.0875 \n",
+ "ethnicity_3 0.00 (0.00) -0.1521* \n",
+ "ethnicity_4 0.52 (0.50) 0.0852 \n",
+ "ethnicity_5 0.05 (0.21) 0.0352 \n",
+ "ethnicity_6 0.00 (0.06) 0.0098 \n",
+ "ethnicity_7 0.00 (0.00) -0.0863 \n",
+ "ethnicity_8 0.02 (0.15) 0.0458 \n",
+ "ethnicity_9 0.06 (0.24) 0.2440* \n",
+ "gender_2 0.53 (0.50) 0.0646 \n",
+ "school_achievement 0.05 (1.00) 0.0132 \n",
+ "school_ethnic_minority -0.14 (1.30) -0.0445 \n",
+ "school_mindset 0.26 (0.65) 0.3264* \n",
+ "school_poverty 0.12 (0.57) 0.1968* \n",
+ "school_size -0.19 (1.06) -0.1365* \n",
+ "school_urbanicity_1 0.06 (0.24) -0.5145* \n",
+ "school_urbanicity_2 0.05 (0.21) -0.4412* \n",
+ "school_urbanicity_3 0.24 (0.43) 0.2238* \n",
+ "school_urbanicity_4 0.65 (0.48) 0.6475* "
+ ]
+ },
+ "execution_count": 143,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from causalml.match import create_table_one\n",
+ "\n",
+ "table_pre_atc = create_table_one(\n",
+ " data=df_match,\n",
+ " treatment_col=\"intervention\",\n",
+ " features=features_cov,\n",
+ " with_std=True,\n",
+ " with_counts=True\n",
+ ")\n",
+ "\n",
+ "table_post_atc = create_table_one(\n",
+ " data=matched_atc,\n",
+ " treatment_col=\"intervention\",\n",
+ " features=features_cov,\n",
+ " with_std=True,\n",
+ " with_counts=True\n",
+ ")\n",
+ "\n",
+ "balance_atc = combine_table_one(table_pre_atc, table_post_atc, smd_star_thresh=0.1)\n",
+ "balance_atc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "ATC 매칭은 caliper와 복원 여부에 따라 매칭률과 공변량 균형 사이에 trade-off가 있습니다. \n",
+ "- 조건이 엄격하면 공통지지 부족으로 매칭이 불가능하고, \n",
+ "- 조건을 완화하면 반복 사용과 먼 거리 매칭으로 SMD가 악화됩니다.\n",
+ "\n",
+ "따라서 처치군과 대조군의 표본 불균형이 큰 경우, 매칭만으로는 ATE 추정에 제약이 있습니다. 모집단 평균 효과를 안정적으로 추정하려면 IPW, AIPW 등 가중치 기반 방법이 더 적합합니다."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### ATE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 144,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "ATE: 0.403\n"
+ ]
+ }
+ ],
+ "source": [
+ "p_t = float(df[\"intervention\"].mean())\n",
+ "ATE = ATT * p_t + ATC * (1 - p_t)\n",
+ "print(\"ATE:\", round(ATE, 4))"
]
},
{
@@ -33,7 +1173,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "base",
+ "display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -47,7 +1187,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.7"
+ "version": "3.12.11"
}
},
"nbformat": 4,