Skip to content

Commit 98d921d

Browse files
committed
Add structured nuisance-model quality diagnostics
1 parent 1a3b54f commit 98d921d

8 files changed

Lines changed: 384 additions & 1 deletion

File tree

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ pip install -e .
5656

5757
В `compare_policies(...).to_dict()` и `diagnostics` возвращаются `propensity_source` и `propensity_column` (если применимо).
5858

59+
### Nuisance model diagnostics
60+
61+
В high-level summary добавлен блок `nuisance_diagnostics`:
62+
- behavior-side quality для estimated propensity path (например multiclass log-loss, top-1 agreement);
63+
- outcome-side quality (`accept`: log-loss/Brier/AUC, `cltv`: RMSE/MAE/R²);
64+
- маркеры `applicable` и `is_out_of_fold`, чтобы явно различать logged path и cross-fit OOF path.
65+
66+
Важно: diagnostics по весам/overlap и diagnostics качества nuisance дополняют друг друга; ни один из блоков сам по себе не гарантирует корректность оценки на реальных данных.
67+
5968

6069
## Simulation validation harness (synthetic oracle checks)
6170

docs/architecture.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,17 @@ Harness поддерживает сравнение методов (`replay`, `i
121121
Назначение — внутренний validation/regression инструмент для развития библиотеки. Это не универсальная гарантия корректности на произвольных real-world логах.
122122

123123
Подробнее: `docs/validation_harness.md`.
124+
125+
126+
## 10) Nuisance-model quality diagnostics
127+
128+
Дополнительно к overlap/weight diagnostics введён отдельный слой `nuisance diagnostics` для качества моделей `pi_hat` и `mu_hat`.
129+
130+
Зачем: хорошие ESS/overlap сами по себе не гарантируют, что nuisance-модели адекватны. Поэтому в structured outputs добавляются behavior/outcome quality метрики и warning rules.
131+
132+
Ключевые принципы:
133+
- logged propensity path: behavior model quality помечается как not applicable;
134+
- estimated propensity path: behavior quality считается и добавляется в summary;
135+
- cross-fit mode: diagnostics отмечаются как OOF (fold-aware provenance).
136+
137+
Этот слой не меняет формулы estimators и служит для trust-quality интерпретации результатов.

docs/validation_harness.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
- `Delta_CI` coverage (если CI рассчитан);
1212
- частота significance decision (`is_significant`);
1313
- diagnostics-поля (например, `weight_ess_ratio`, `weight_p99`);
14-
- provenance (`propensity_source_used`, `propensity_column_used`).
14+
- provenance (`propensity_source_used`, `propensity_column_used`);
15+
- nuisance-quality summaries (например behavior log-loss, outcome log-loss/RMSE) для сравнения режимов.
1516

1617
На уровне aggregate (по `mode` и `estimator`):
1718
- mean bias, std, RMSE для `V_B` и `delta`;

src/policyscope/comparison.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from policyscope.diagnostics import compute_policy_diagnostics, PolicyDiagnostics
1313
from policyscope.inference import infer_policy_comparison_bootstrap
1414
from policyscope.estimators import value_on_policy
15+
from policyscope.nuisance_diagnostics import NuisanceDiagnostics, compute_nuisance_diagnostics
1516
from policyscope.nuisance import (
1617
CrossFitNuisanceBundle,
1718
PropensitySource,
@@ -48,6 +49,7 @@ class PolicyComparisonSummary:
4849
notes: tuple[str, ...] = field(default_factory=tuple)
4950
propensity_source: Optional[str] = None
5051
propensity_column: Optional[str] = None
52+
nuisance_diagnostics: Optional[NuisanceDiagnostics] = None
5153

5254
def to_dict(self) -> dict:
5355
out = {
@@ -83,6 +85,8 @@ def to_dict(self) -> dict:
8385
out["propensity_source"] = self.propensity_source
8486
if self.propensity_column is not None:
8587
out["propensity_column"] = self.propensity_column
88+
if self.nuisance_diagnostics is not None:
89+
out["nuisance_diagnostics"] = self.nuisance_diagnostics.to_dict()
8690
return out
8791

8892

@@ -226,6 +230,19 @@ def point_on(part: pd.DataFrame) -> float:
226230
propensity_col=propensity_col,
227231
)
228232

233+
nuisance_diag = compute_nuisance_diagnostics(
234+
df,
235+
target=target,
236+
estimator=estimator,
237+
feature_cols=feature_cols,
238+
action_col=action_col,
239+
propensity_source=diag.propensity_source or resolved_source,
240+
behavior_predictions=(
241+
nuisance_bundle.behavior if nuisance_bundle is not None and nuisance_bundle.behavior is not None else resolved_behavior
242+
),
243+
nuisance_bundle=nuisance_bundle,
244+
)
245+
229246
if not with_ci:
230247
return PolicyComparisonSummary(
231248
estimator=estimator,
@@ -237,6 +254,7 @@ def point_on(part: pd.DataFrame) -> float:
237254
notes=propensity_notes + tuple(diag.warnings),
238255
propensity_source=diag.propensity_source or resolved_source,
239256
propensity_column=diag.propensity_column or resolved_propensity_col,
257+
nuisance_diagnostics=nuisance_diag,
240258
)
241259

242260
def estimator_pair(part: pd.DataFrame):
@@ -272,6 +290,7 @@ def estimator_pair(part: pd.DataFrame):
272290
notes=notes,
273291
propensity_source=diag.propensity_source or resolved_source,
274292
propensity_column=diag.propensity_column or resolved_propensity_col,
293+
nuisance_diagnostics=nuisance_diag,
275294
)
276295

277296

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""Nuisance-model quality diagnostics for behavior and outcome models."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import Optional, Sequence
7+
8+
import numpy as np
9+
import pandas as pd
10+
from sklearn.metrics import brier_score_loss, log_loss, mean_absolute_error, mean_squared_error, r2_score, roc_auc_score
11+
12+
from policyscope.estimators import mu_hat_predict
13+
from policyscope.nuisance import (
14+
BehaviorPredictions,
15+
CrossFitNuisanceBundle,
16+
OutcomePredictions,
17+
fit_outcome_nuisance_bundle,
18+
)
19+
20+
21+
@dataclass(frozen=True)
22+
class BehaviorModelDiagnostics:
23+
applicable: bool
24+
propensity_source: str
25+
is_out_of_fold: bool
26+
multiclass_log_loss: Optional[float] = None
27+
top1_accuracy: Optional[float] = None
28+
mean_logged_action_prob: Optional[float] = None
29+
warnings: tuple[str, ...] = field(default_factory=tuple)
30+
31+
def to_dict(self) -> dict:
32+
return {
33+
"applicable": self.applicable,
34+
"propensity_source": self.propensity_source,
35+
"is_out_of_fold": self.is_out_of_fold,
36+
"multiclass_log_loss": self.multiclass_log_loss,
37+
"top1_accuracy": self.top1_accuracy,
38+
"mean_logged_action_prob": self.mean_logged_action_prob,
39+
"warnings": list(self.warnings),
40+
}
41+
42+
43+
@dataclass(frozen=True)
44+
class OutcomeModelDiagnostics:
45+
applicable: bool
46+
target: str
47+
is_binary_target: bool
48+
is_out_of_fold: bool
49+
log_loss: Optional[float] = None
50+
brier_score: Optional[float] = None
51+
roc_auc: Optional[float] = None
52+
rmse: Optional[float] = None
53+
mae: Optional[float] = None
54+
r2: Optional[float] = None
55+
warnings: tuple[str, ...] = field(default_factory=tuple)
56+
57+
def to_dict(self) -> dict:
58+
return {
59+
"applicable": self.applicable,
60+
"target": self.target,
61+
"is_binary_target": self.is_binary_target,
62+
"is_out_of_fold": self.is_out_of_fold,
63+
"log_loss": self.log_loss,
64+
"brier_score": self.brier_score,
65+
"roc_auc": self.roc_auc,
66+
"rmse": self.rmse,
67+
"mae": self.mae,
68+
"r2": self.r2,
69+
"warnings": list(self.warnings),
70+
}
71+
72+
73+
@dataclass(frozen=True)
74+
class NuisanceDiagnostics:
75+
behavior: BehaviorModelDiagnostics
76+
outcome: OutcomeModelDiagnostics
77+
warnings: tuple[str, ...] = field(default_factory=tuple)
78+
79+
def to_dict(self) -> dict:
80+
return {
81+
"behavior": self.behavior.to_dict(),
82+
"outcome": self.outcome.to_dict(),
83+
"warnings": list(self.warnings),
84+
}
85+
86+
87+
def _compute_behavior_diagnostics(
88+
df: pd.DataFrame,
89+
*,
90+
action_col: str,
91+
behavior_predictions: Optional[BehaviorPredictions],
92+
propensity_source: Optional[str],
93+
) -> BehaviorModelDiagnostics:
94+
if behavior_predictions is None or propensity_source not in {"estimated", "auto"}:
95+
return BehaviorModelDiagnostics(
96+
applicable=False,
97+
propensity_source=propensity_source or "unknown",
98+
is_out_of_fold=False,
99+
warnings=("behavior_model_not_applicable_for_logged_propensity",),
100+
)
101+
102+
y = df[action_col].to_numpy()
103+
p_taken = np.clip(behavior_predictions.pA_taken, 1e-12, 1.0)
104+
ll = float(-np.mean(np.log(p_taken)))
105+
top1 = None
106+
if behavior_predictions.pA_all is not None:
107+
top1 = float(np.mean(np.argmax(behavior_predictions.pA_all, axis=1) == y))
108+
warnings: list[str] = []
109+
if ll > 1.2:
110+
warnings.append("weak_behavior_log_loss")
111+
if top1 is not None and top1 < 0.4:
112+
warnings.append("weak_behavior_top1_accuracy")
113+
114+
return BehaviorModelDiagnostics(
115+
applicable=True,
116+
propensity_source=propensity_source or behavior_predictions.propensity_source or "estimated",
117+
is_out_of_fold=bool(behavior_predictions.is_out_of_fold),
118+
multiclass_log_loss=ll,
119+
top1_accuracy=top1,
120+
mean_logged_action_prob=float(np.mean(p_taken)),
121+
warnings=tuple(warnings),
122+
)
123+
124+
125+
def _compute_outcome_diagnostics(
126+
df: pd.DataFrame,
127+
*,
128+
target: str,
129+
feature_cols: Optional[Sequence[str]],
130+
action_col: str,
131+
estimator: str,
132+
outcome_predictions: Optional[OutcomePredictions],
133+
) -> OutcomeModelDiagnostics:
134+
if estimator not in {"dm", "dr", "sndr", "switch_dr"}:
135+
return OutcomeModelDiagnostics(
136+
applicable=False,
137+
target=target,
138+
is_binary_target=False,
139+
is_out_of_fold=False,
140+
warnings=("outcome_model_not_used_for_estimator",),
141+
)
142+
143+
y = df[target].to_numpy()
144+
is_binary = np.array_equal(np.unique(y), np.array([0, 1])) or np.array_equal(np.unique(y), np.array([0.0, 1.0]))
145+
146+
if outcome_predictions is None:
147+
mu_bundle = fit_outcome_nuisance_bundle(df, target=target, feature_cols=feature_cols, action_col=action_col)
148+
pred = mu_hat_predict(mu_bundle.mu_model, df, df[action_col].to_numpy(), target)
149+
is_oof = False
150+
else:
151+
pred = outcome_predictions.mu_logged_action
152+
is_oof = bool(outcome_predictions.is_out_of_fold)
153+
154+
warnings: list[str] = []
155+
if is_binary:
156+
p = np.clip(pred, 1e-12, 1 - 1e-12)
157+
ll = float(log_loss(y, p, labels=[0, 1]))
158+
br = float(brier_score_loss(y, p))
159+
try:
160+
auc = float(roc_auc_score(y, p))
161+
except ValueError:
162+
auc = None
163+
if ll > 0.69:
164+
warnings.append("weak_outcome_log_loss")
165+
if br > 0.25:
166+
warnings.append("weak_outcome_brier")
167+
if auc is not None and auc < 0.6:
168+
warnings.append("weak_outcome_auc")
169+
return OutcomeModelDiagnostics(
170+
applicable=True,
171+
target=target,
172+
is_binary_target=True,
173+
is_out_of_fold=is_oof,
174+
log_loss=ll,
175+
brier_score=br,
176+
roc_auc=auc,
177+
warnings=tuple(warnings),
178+
)
179+
180+
rmse = float(np.sqrt(mean_squared_error(y, pred)))
181+
mae = float(mean_absolute_error(y, pred))
182+
r2 = float(r2_score(y, pred))
183+
if r2 < 0.0:
184+
warnings.append("weak_outcome_r2")
185+
return OutcomeModelDiagnostics(
186+
applicable=True,
187+
target=target,
188+
is_binary_target=False,
189+
is_out_of_fold=is_oof,
190+
rmse=rmse,
191+
mae=mae,
192+
r2=r2,
193+
warnings=tuple(warnings),
194+
)
195+
196+
197+
def compute_nuisance_diagnostics(
198+
df: pd.DataFrame,
199+
*,
200+
target: str,
201+
estimator: str,
202+
feature_cols: Optional[Sequence[str]],
203+
action_col: str,
204+
propensity_source: Optional[str],
205+
behavior_predictions: Optional[BehaviorPredictions] = None,
206+
nuisance_bundle: Optional[CrossFitNuisanceBundle] = None,
207+
) -> NuisanceDiagnostics:
208+
"""Compute structured nuisance quality diagnostics for official outputs."""
209+
if behavior_predictions is None and nuisance_bundle is not None:
210+
behavior_predictions = nuisance_bundle.behavior
211+
outcome_predictions = nuisance_bundle.outcome if nuisance_bundle is not None else None
212+
213+
behavior = _compute_behavior_diagnostics(
214+
df,
215+
action_col=action_col,
216+
behavior_predictions=behavior_predictions,
217+
propensity_source=propensity_source,
218+
)
219+
outcome = _compute_outcome_diagnostics(
220+
df,
221+
target=target,
222+
feature_cols=feature_cols,
223+
action_col=action_col,
224+
estimator=estimator,
225+
outcome_predictions=outcome_predictions,
226+
)
227+
warnings = tuple(list(behavior.warnings) + list(outcome.warnings))
228+
return NuisanceDiagnostics(behavior=behavior, outcome=outcome, warnings=warnings)

src/policyscope/validation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class ValidationRunRow:
4646
propensity_column_used: Optional[str]
4747
ess_ratio: Optional[float]
4848
weight_p99: Optional[float]
49+
behavior_log_loss: Optional[float]
50+
outcome_log_loss: Optional[float]
51+
outcome_rmse: Optional[float]
4952

5053

5154
@dataclass(frozen=True)
@@ -90,6 +93,9 @@ def _aggregate_rows(rows: list[ValidationRunRow]) -> pd.DataFrame:
9093
significance_rate=("is_significant", "mean"),
9194
mean_ess_ratio=("ess_ratio", "mean"),
9295
mean_weight_p99=("weight_p99", "mean"),
96+
mean_behavior_log_loss=("behavior_log_loss", "mean"),
97+
mean_outcome_log_loss=("outcome_log_loss", "mean"),
98+
mean_outcome_rmse=("outcome_rmse", "mean"),
9399
)
94100
.reset_index()
95101
)
@@ -187,6 +193,21 @@ def run_simulation_validation(
187193
propensity_column_used=summary.propensity_column,
188194
ess_ratio=diag.get("weight_ess_ratio"),
189195
weight_p99=diag.get("weight_p99"),
196+
behavior_log_loss=(
197+
summary.nuisance_diagnostics.behavior.multiclass_log_loss
198+
if summary.nuisance_diagnostics is not None
199+
else None
200+
),
201+
outcome_log_loss=(
202+
summary.nuisance_diagnostics.outcome.log_loss
203+
if summary.nuisance_diagnostics is not None
204+
else None
205+
),
206+
outcome_rmse=(
207+
summary.nuisance_diagnostics.outcome.rmse
208+
if summary.nuisance_diagnostics is not None
209+
else None
210+
),
190211
)
191212
)
192213

tests/test_docs_consistency.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def test_architecture_doc_exists_and_mentions_domain_model():
1515
assert "compare_policies" in text
1616
assert 'propensity_source="auto"' in text
1717
assert "logged vs estimated propensity" in text.lower()
18+
assert "nuisance-model quality diagnostics" in text.lower()
1819

1920

2021
def test_readme_mentions_p_value_method():
@@ -24,6 +25,7 @@ def test_readme_mentions_p_value_method():
2425
assert "weight_ess_ratio" in text
2526
assert "compare_policies_multi_target" in text
2627
assert "propensity source modes" in text
28+
assert "nuisance model diagnostics" in text
2729
assert 'propensity_source="auto"' in text
2830

2931

0 commit comments

Comments
 (0)