Skip to content

Commit 0189130

Browse files
committed
wip: add yeo-johnson
1 parent 7cd135f commit 0189130

File tree

3 files changed

+541
-0
lines changed

3 files changed

+541
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#' Unnormalizing transformation
#'
#' Will undo a step_YeoJohnson2 transformation.
#'
#' NOTE(review): this layer is work in progress. As currently applied by its
#' `slather()` method, it joins the predictions with `df` and multiplies the
#' selected columns by `df[[df_pop_col]] / rate_rescaling`; it does not yet
#' apply an inverse Yeo-Johnson transform.
#'
#' @param frosting a `frosting` postprocessor. The layer will be added to the
#'   sequence of operations for this frosting.
#' @param ... One or more selector functions to scale variables
#'   for this step. See [recipes::selections()] for more details.
#' @param df a data frame that contains the population data to be used for
#'   inverting the existing scaling.
#' @param by A (possibly named) character vector of variables to join by.
#'   If `NULL` (the default), the join columns are chosen when the layer is
#'   applied, as the prediction key columns that also appear in `df`
#'   (excluding `df_pop_col`).
#' @param df_pop_col a string; the name of the column in `df` that holds the
#'   values used for rescaling.
#' @param rate_rescaling a single positive number. Selected columns are
#'   multiplied by `df[[df_pop_col]] / rate_rescaling`.
#' @param create_new a single logical. If `TRUE`, the rescaled values are
#'   written to new columns named by appending `suffix` to the selected column
#'   names; if `FALSE`, the selected columns are overwritten.
#' @param suffix a string appended to the selected column names when
#'   `create_new = TRUE`.
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
#' @export
#' @examples
#' library(dplyr)
#' jhu <- epidatasets::cases_deaths_subset %>%
#'   filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#'   select(geo_value, time_value, cases)
#'
#' pop_data <- data.frame(states = c("ca", "ny"), value = c(20000, 30000))
#'
#' r <- epi_recipe(jhu) %>%
#'   step_YeoJohnson2(
#'     df = pop_data,
#'     df_pop_col = "value",
#'     by = c("geo_value" = "states"),
#'     cases, suffix = "_scaled"
#'   ) %>%
#'   step_epi_lag(cases_scaled, lag = c(0, 7, 14)) %>%
#'   step_epi_ahead(cases_scaled, ahead = 7, role = "outcome") %>%
#'   step_epi_naomit()
#'
#' f <- frosting() %>%
#'   layer_predict() %>%
#'   layer_threshold(.pred) %>%
#'   layer_naomit(.pred) %>%
#'   layer_YeoJohnson2(.pred,
#'     df = pop_data,
#'     by = c("geo_value" = "states"),
#'     df_pop_col = "value"
#'   )
#'
#' wf <- epi_workflow(r, linear_reg()) %>%
#'   fit(jhu) %>%
#'   add_frosting(f)
#'
#' forecast(wf)
layer_YeoJohnson2 <- function(frosting,
                              ...,
                              df,
                              by = NULL,
                              df_pop_col,
                              rate_rescaling = 1,
                              create_new = TRUE,
                              suffix = "_scaled",
                              id = rand_id("YeoJohnson2")) {
  # These arguments were validated and forwarded below but were missing from
  # the signature, so every call failed with "object not found". They sit
  # after `...` and must therefore be passed by name.
  arg_is_scalar(df_pop_col, rate_rescaling, create_new, suffix, id)
  arg_is_lgl(create_new)
  arg_is_chr(df_pop_col, suffix, id)
  arg_is_chr(by, allow_null = TRUE)
  if (rate_rescaling <= 0) {
    cli_abort("`rate_rescaling` must be a positive number.")
  }

  add_layer(
    frosting,
    layer_YeoJohnson2_new(
      df = df,
      by = by,
      df_pop_col = df_pop_col,
      rate_rescaling = rate_rescaling,
      # Capture the selectors unevaluated; they are resolved against the
      # predictions when the layer is applied.
      terms = dplyr::enquos(...),
      create_new = create_new,
      suffix = suffix,
      id = id
    )
  )
}
77+
78+
# Internal constructor: bundles the validated arguments into a "YeoJohnson2"
# layer object via `layer()`. Performs no checking of its own.
layer_YeoJohnson2_new <- function(df,
                                  by,
                                  df_pop_col,
                                  rate_rescaling,
                                  terms,
                                  create_new,
                                  suffix,
                                  id) {
  layer(
    "YeoJohnson2",
    df = df,
    by = by,
    df_pop_col = df_pop_col,
    rate_rescaling = rate_rescaling,
    terms = terms,
    create_new = create_new,
    suffix = suffix,
    id = id
  )
}
91+
92+
#' @export
slather.layer_YeoJohnson2 <-
  function(object, components, workflow, new_data, ...) {
    # Applies the layer: joins `components$predictions` with `object$df` and
    # multiplies the selected prediction columns by
    # `df[[df_pop_col]] / rate_rescaling`. Returns the updated `components`.
    rlang::check_dots_empty()

    if (is.null(object$by)) {
      # Assume `layer_predict` has calculated the prediction keys and other
      # layers don't change the prediction key colnames:
      prediction_key_colnames <- names(components$keys)
      lhs_potential_keys <- prediction_key_colnames
      rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col))
      object$by <- intersect(lhs_potential_keys, rhs_potential_keys)
      # Presumably the prediction keys minus the time column; warn when the
      # inferred join does not cover all of them.
      suggested_min_keys <- kill_time_value(lhs_potential_keys)
      if (!all(suggested_min_keys %in% object$by)) {
        cli_warn(c(
          "{setdiff(suggested_min_keys, object$by)} {?was an/were} epikey column{?s} in the predictions,
           but {?wasn't/weren't} found in the population `df`.",
          "i" = "Defaulting to join by {object$by}",
          ">" = "Double-check whether column names on the population `df` match those expected in your predictions",
          ">" = "Consider using population data with breakdowns by {suggested_min_keys}",
          ">" = "Manually specify `by =` to silence"
        ), class = "epipredict__layer_YeoJohnson2__default_by_missing_suggested_keys")
      }
    }

    # Fallback kept from the original in case `by` is still NULL here.
    object$by <- object$by %||% intersect(
      epi_keys_only(components$predictions),
      colnames(select(object$df, !object$df_pop_col))
    )
    # For a named `by`, the names are the prediction-side columns and the
    # values are the `df`-side columns; an unnamed `by` is used for both.
    joinby <- list(x = names(object$by) %||% object$by, y = object$by)
    hardhat::validate_column_names(components$predictions, joinby$x)
    hardhat::validate_column_names(object$df, joinby$y)

    # object$df <- object$df %>%
    #   dplyr::mutate(dplyr::across(tidyselect::where(is.character), tolower))
    pop_col <- rlang::sym(object$df_pop_col)
    # Resolve the user's selectors against the prediction columns.
    exprs <- rlang::expr(c(!!!object$terms))
    pos <- tidyselect::eval_select(exprs, components$predictions)
    col_names <- names(pos)
    # `create_new` is a validated scalar, so plain `if` replaces `ifelse()`.
    suffix <- if (object$create_new) object$suffix else ""
    # Columns that only exist in `df` are dropped again after rescaling so the
    # joined-in data does not leak into the predictions.
    col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions))

    components$predictions <- inner_join(
      components$predictions,
      object$df,
      by = object$by,
      relationship = "many-to-one",
      unmatched = c("error", "drop"),
      suffix = c("", ".df")
    ) %>%
      mutate(across(
        all_of(col_names),
        ~ .x * !!pop_col / object$rate_rescaling,
        .names = "{.col}{suffix}"
      )) %>%
      select(-any_of(col_to_remove))
    components
  }
151+
152+
#' @export
print.layer_YeoJohnson2 <- function(
    x, width = max(20, options()$width - 30), ...) {
  # Delegate to the shared layer printer, showing the selected terms under a
  # fixed title.
  print_layer(
    x$terms,
    title = "Scaling predictions by population",
    width = width
  )
}

0 commit comments

Comments
 (0)