Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions book/cate_and_policy/policy_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,71 @@
" async>\n",
"</script>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 라이브러리와 난수 고정\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle\n",
"rng = np.random.default_rng(42)\n",
"\n",
"# 데이터 생성\n",
"n = 1000\n",
"p = 4\n",
"X = rng.random((n, p)) # 0~1 균등분포 공변량 4개\n",
"W = rng.binomial(1, 0.5, size=n) # 무작위 처치(0/1), 확률 0.5\n",
"Y = 0.5*(X[:, 0] - 0.5) + (X[:, 1] - 0.5)*W + 0.1*rng.normal(size=n)\n",
"\n",
"# 시각화\n",
"y_norm = 1 - (Y - Y.min())/(Y.max() - Y.min()) # 0~1로 정규화\n",
"gray_colors = np.array([str(v) for v in y_norm])\n",
"\n",
"plt.scatter(X[:, 0], X[:, 1], c=gray_colors, s=60, marker='o',\n",
" edgecolors='k', linewidths=0.5)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Rectangle\n",
"\n",
"plt.figure(figsize=(6, 5))\n",
"\n",
"# 1) 구역 칠하기 (사각형 3개)\n",
"col_treat = (0.25, 0.69, 0.65, 0.35) # 초록 투명\n",
"col_notreat = (0.996, 0.754, 0.027, 0.35) # 노랑 투명\n",
"\n",
"# 왼쪽(0~0.5, 전체 y)\n",
"plt.gca().add_patch(Rectangle((-.1, -.1), 0.6, 1.2, facecolor=col_notreat, edgecolor='none', hatch='///'))\n",
"# 오른쪽 아래(0.5~1, 0~0.5)\n",
"plt.gca().add_patch(Rectangle((0.5, -.1), 0.6, 0.6, facecolor=col_notreat, edgecolor='none', hatch='///'))\n",
"# 오른쪽 위(0.5~1, 0.5~1)\n",
"plt.gca().add_patch(Rectangle((0.5, 0.5), 0.6, 0.6, facecolor=col_treat, edgecolor='none', hatch='///'))\n",
"\n",
"# 2) 점 찍기\n",
"plt.scatter(X[W==0,0], X[W==0,1],\n",
" c=y_norm[W==0], cmap='gray', vmin=0, vmax=1,\n",
" s=60, marker='^', edgecolors='k', linewidths=0.5,\n",
" label=\"Untreated\")\n",
"plt.scatter(X[W==1,0], X[W==1,1],\n",
" c=y_norm[W==1], cmap='gray', vmin=0, vmax=1,\n",
" s=60, marker='o', edgecolors='k', linewidths=0.5,\n",
" label=\"Treated\")\n",
"\n",
"# 3) 텍스트 라벨 붙이기\n",
"plt.text(0.75, 0.75, \"TREAT (A)\", fontsize=14, ha='center', va='center')\n",
"plt.text(0.25, 0.25, \"DO NOT TREAT (A^C)\", fontsize=12, ha='center', va='center')\n",
"\n",
"plt.xlim(-0.1, 1.1)\n",
"plt.ylim(-0.1, 1.1)\n",
"plt.xlabel(\"X1\"); plt.ylabel(\"X2\")\n",
"plt.title(\"Policy Regions with Treated vs Untreated\")\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
}
],
"metadata": {
Expand Down