Skip to content

Commit

Permalink
Merge pull request #19 from dobengjhu/missingdata
Browse files Browse the repository at this point in the history
Add checks for critical missing data
  • Loading branch information
dobengjhu authored Feb 11, 2025
2 parents 6441e4a + 1f95c6d commit 6a98f23
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 5 deletions.
24 changes: 23 additions & 1 deletion R/estimate_cate.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
#' @param trial_tbl tbl. A tbl containing columns for treatment, outcome, study ID, and any
#' additional covariates of interest. All study data must be included in single tbl. Note that
#' only two treatments can be considered and treatment must be coded as 0/1 (numeric).
#' Additionally, all study ID, treatment, outcome, and covariate variable values must be
#' non-missing.
#' @param estimation_method string. Single-study methods for estimating CATE (tau) for each
#' observation. Available methods are "slearner" (using Bayesian Additive Regression Trees),
#' and "causalforest".
Expand Down Expand Up @@ -144,9 +146,14 @@ estimate_cate <- function(trial_tbl,

assert_column_class(trial_tbl, treatment_col, c("numeric", "integer"))

assertthat::assert_that(
!any(is.na(trial_tbl[treatment_col])),
msg = "`treatment_col` cannot include missing values."
)

assertthat::assert_that(
length(setdiff(unique(trial_tbl[[treatment_col]]), c(0,1))) == 0,
msg = "Treatment values must be 0, 1, or NA."
msg = "Treatment values must be 0 or 1."
)

assert_column_class(trial_tbl, outcome_col, c("numeric", "integer"))
Expand All @@ -156,6 +163,21 @@ estimate_cate <- function(trial_tbl,
covariate_col <- colnames(trial_tbl)[-which(colnames(trial_tbl) %in% exclude_col)]
}

assertthat::assert_that(
!any(is.na(trial_tbl[outcome_col])),
msg = "`outcome_col` cannot include missing values."
)

assertthat::assert_that(
!any(is.na(trial_tbl[study_col])),
msg = "`study_col` cannot include missing values."
)

assertthat::assert_that(
!any(is.na(trial_tbl %>% dplyr::select(dplyr::all_of(covariate_col)))),
msg = "Variables included in `covariate_col` cannot include missing values."
)

named_args <- list(trial_tbl = trial_tbl,
study_col = study_col,
treatment_col = treatment_col,
Expand Down
4 changes: 3 additions & 1 deletion man/estimate_cate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 42 additions & 2 deletions tests/testthat/_snaps/estimate_cate.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# estimate_cate raises error for invalid treatment and/or response values
# estimate_cate raises error for invalid variable values

Code
estimate_cate(trial_tbl = dummy_tbl %>% dplyr::mutate(tx = paste0("Treatment ",
Expand All @@ -16,7 +16,7 @@
study_col = "studyid", treatment_col = "tx", outcome_col = "response")
Condition
Error:
! Treatment values must be 0, 1, or NA.
! Treatment values must be 0 or 1.

---

Expand Down Expand Up @@ -88,3 +88,43 @@
foobar

---

Code
estimate_cate(trial_tbl = dummy_tbl_study_na, estimation_method = "causalforest",
aggregation_method = "studyindicator", study_col = "studyid", treatment_col = "tx",
outcome_col = "response")
Condition
Error:
! `study_col` cannot include missing values.

---

Code
estimate_cate(trial_tbl = dummy_tbl_treatment_na, estimation_method = "causalforest",
aggregation_method = "studyindicator", study_col = "studyid", treatment_col = "tx",
outcome_col = "response")
Condition
Error:
! `treatment_col` cannot include missing values.

---

Code
estimate_cate(trial_tbl = dummy_tbl_outcome_na, estimation_method = "causalforest",
aggregation_method = "studyindicator", study_col = "studyid", treatment_col = "tx",
outcome_col = "response")
Condition
Error:
! `outcome_col` cannot include missing values.

---

Code
estimate_cate(trial_tbl = dummy_tbl_covariate_na, estimation_method = "causalforest",
aggregation_method = "studyindicator", study_col = "studyid", treatment_col = "tx",
outcome_col = "response")
Condition
Error:
! Variables included in `covariate_col` cannot include missing values.

62 changes: 61 additions & 1 deletion tests/testthat/test-estimate_cate.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
dummy_tbl_extra_var <- dummy_tbl %>%
dplyr::mutate(other = "foobar")

dummy_tbl_study_na <- dummy_tbl %>%
dplyr::mutate(studyid = ifelse(dplyr::row_number() == 10, NA, studyid))

dummy_tbl_treatment_na <- dummy_tbl %>%
dplyr::mutate(tx = ifelse(dplyr::row_number() == 10, NA, tx))

dummy_tbl_outcome_na <- dummy_tbl %>%
dplyr::mutate(response = ifelse(dplyr::row_number() == 10, NA, response))

dummy_tbl_covariate_na <- dummy_tbl %>%
dplyr::mutate(var4 = ifelse(dplyr::row_number() == 10, NA, var4))

expected_object_names <- c("estimation_method",
"aggregation_method",
"model",
Expand Down Expand Up @@ -192,7 +204,7 @@ test_that("estimate_cate returns correct structure with valid inputs (slearner /
expect_true(length(result$estimation_object) == 3)
})

test_that("estimate_cate raises error for invalid treatment and/or response values", {
test_that("estimate_cate raises error for invalid variable values", {
expect_snapshot(
estimate_cate(
trial_tbl = dummy_tbl %>% dplyr::mutate(tx = paste0("Treatment ", tx)),
Expand Down Expand Up @@ -292,4 +304,52 @@ test_that("estimate_cate raises error for missing columns", {
),
error = TRUE
)

expect_snapshot(
estimate_cate(
trial_tbl = dummy_tbl_study_na,
estimation_method = "causalforest",
aggregation_method = "studyindicator",
study_col = "studyid",
treatment_col = "tx",
outcome_col = "response"
),
error = TRUE
)

expect_snapshot(
estimate_cate(
trial_tbl = dummy_tbl_treatment_na,
estimation_method = "causalforest",
aggregation_method = "studyindicator",
study_col = "studyid",
treatment_col = "tx",
outcome_col = "response"
),
error = TRUE
)

expect_snapshot(
estimate_cate(
trial_tbl = dummy_tbl_outcome_na,
estimation_method = "causalforest",
aggregation_method = "studyindicator",
study_col = "studyid",
treatment_col = "tx",
outcome_col = "response"
),
error = TRUE
)

expect_snapshot(
estimate_cate(
trial_tbl = dummy_tbl_covariate_na,
estimation_method = "causalforest",
aggregation_method = "studyindicator",
study_col = "studyid",
treatment_col = "tx",
outcome_col = "response"
),
error = TRUE
)
})

0 comments on commit 6a98f23

Please sign in to comment.