Skip to content

Refactored scikit-learn flavour of DifferenceInDifferences and allowed custom column names for post_treatment variable. #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

roesta07
Copy link

@roesta07 roesta07 commented Jul 30, 2025

closes issues #390 and #514

  • causal impact calculation in scikit-learn flavour of DifferenceInDifferences
  • Allow the user to use whatever column name they want for 'post_treatment' variable while constructing DifferenceInDifferences object with new parameter post_treatment_variable_name . Also setting its default value to 'post_treatment' so that it does not break previously written codes.

📚 Documentation preview 📚: https://causalpy--515.org.readthedocs.build/en/515/

Copy link

codecov bot commented Jul 30, 2025

Codecov Report

❌ Patch coverage is 88.88889% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.13%. Comparing base (09adfd7) to head (7fbb27a).

Files with missing lines Patch % Lines
causalpy/experiments/diff_in_diff.py 88.88% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #515      +/-   ##
==========================================
- Coverage   95.19%   95.13%   -0.06%     
==========================================
  Files          28       28              
  Lines        2457     2468      +11     
==========================================
+ Hits         2339     2348       +9     
- Misses        118      120       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@drbenvincent drbenvincent left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Looks like the remote checks are failing. Sometimes you need to run the pre-commit checks locally twice - the interrogate thing is a bit fiddly.
  • And looks like we'll need to increase test coverage. So obvious ones would be to include tests where we use the default, or a user-provided post treatment variable name.

Overall, this is looking good. Thanks for the PR :)

Oh, remember to update from main regularly :)

)
# Check if post_treatment_variable_name is in formula
if self.post_treatment_variable_name not in self.formula:
if self.post_treatment_variable_name == "post_treatment":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got a minor preference to just give one generic exception message, rather than a custom one dependent on self.post_treatment_variable_name. That will also cut down on the number of tests required to achieve high test coverage.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah absolutely!! More generic ones like "Missing required variable '{self.post_treatment_variable_name}' in formula" can be used


# Check if post_treatment_variable_name is in data columns
if self.post_treatment_variable_name not in self.data.columns:
if self.post_treatment_variable_name == "post_treatment":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above. Just give one more generic exception message, regardless of what self.post_treatment_variable_name is.

# Store the coefficient into dictionary {intercept:value}
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
# Create and find the interaction term based on the values user provided
interaction_term = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. We'll need more tests anyway to ensure test coverage, so when you do that can you add cases for when people specify formulas like post_treatment:a and post_treatment*b. It should work because we'll always get a coefficient for post_treatment:a, but it is worth adding the test

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, will add some tests for a cases where a user provides post treatment variable name and check for FormulaExeption and DataException

but @drbenvincent can you elaborate on this specific test. Are we also checking the coefficient value where two interaction terms are used?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd not thought of that. I guess it's easy to find and interaction term of the post treatment variable and something else. But if there are two interaction terms, both including the post treatment variable, then that might get messy. Can we think of any situations where that be a good idea? If not, then maybe that could throw and exception and we just say we can't deal with a formula like that?

@@ -128,6 +130,12 @@ def __init__(
}
self.model.fit(X=self.X, y=self.y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
# For scikit-learn models, automatically set fit_intercept=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants