diff --git a/NAMESPACE b/NAMESPACE index c2fa9494..62e5a6ab 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -167,7 +167,7 @@ export(flatline_forecaster) export(flusight_hub_formatter) export(forecast) export(frosting) -export(get_test_data) +export(get_predict_data) export(is_epi_recipe) export(is_epi_workflow) export(is_layer) diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index f988490f..0c588c3b 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -47,20 +47,19 @@ arx_forecaster <- function( if (!is_regression(trainer)) { cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'regression'.") } - wf <- arx_fcast_epi_workflow(epi_data, outcome, predictors, trainer, args_list) wf <- fit(wf, epi_data) # get the forecast date for the forecast function if (args_list$adjust_latency == "none") { - forecast_date_default <- max(epi_data$time_value) + reference_date_default <- max(epi_data$time_value) } else { - forecast_date_default <- attributes(epi_data)$metadata$as_of + reference_date_default <- attributes(epi_data)$metadata$as_of } - forecast_date <- args_list$forecast_date %||% forecast_date_default - + reference_date <- args_list$reference_date %||% reference_date_default + predict_interval <- args_list$predict_interval - preds <- forecast(wf, forecast_date = forecast_date) %>% + preds <- forecast(wf, reference_dates = reference_date, predict_interval = predict_interval) %>% as_tibble() %>% select(-time_value) @@ -126,21 +125,21 @@ arx_fcast_epi_workflow <- function( # if they don't and they're not adjusting latency, it defaults to the max time_value # if they're adjusting, it defaults to the as_of if (args_list$adjust_latency == "none") { - forecast_date_default <- max(epi_data$time_value) - if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) { + reference_date_default <- max(epi_data$time_value) + if (!is.null(args_list$reference_date) && args_list$reference_date != reference_date_default) { cli_warn( - "The 
specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is actually occurring {forecast_date_default}.", + "The specified forecast date {args_list$reference_date} doesn't match the date from which the forecast is actually occurring {reference_date_default}.", class = "epipredict__arx_forecaster__forecast_date_defaulting" ) } } else { - forecast_date_default <- attributes(epi_data)$metadata$as_of + reference_date_default <- attributes(epi_data)$metadata$as_of } - forecast_date <- args_list$forecast_date %||% forecast_date_default - target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) - if (forecast_date + args_list$ahead != target_date) { - cli_abort("`forecast_date` {.val {forecast_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.", - class = "epipredict__arx_forecaster__inconsistent_target_ahead_forecaste_date" + reference_date <- args_list$reference_date %||% reference_date_default + target_date <- args_list$target_date %||% (reference_date + args_list$ahead) + if (reference_date + args_list$ahead != target_date) { + cli_abort("`reference_date` {.val {reference_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.", + class = "epipredict__arx_forecaster__inconsistent_target_ahead_forecast_date" ) } @@ -153,12 +152,12 @@ arx_fcast_epi_workflow <- function( if (!is.null(method_adjust_latency)) { if (method_adjust_latency == "extend_ahead") { r <- r %>% step_adjust_latency(all_outcomes(), - fixed_forecast_date = forecast_date, + fixed_reference_date = reference_date, method = method_adjust_latency ) } else if (method_adjust_latency == "extend_lags") { r <- r %>% step_adjust_latency(all_predictors(), - fixed_forecast_date = forecast_date, + fixed_reference_date = reference_date, method = method_adjust_latency ) } @@ -218,7 +217,7 @@ arx_fcast_epi_workflow <- function( by_key = args_list$quantile_by_key ) } - f <- layer_add_forecast_date(f, 
forecast_date = forecast_date) %>% + f <- layer_add_forecast_date(f, forecast_date = reference_date) %>% layer_add_target_date(target_date = target_date) if (args_list$nonneg) f <- layer_threshold(f, dplyr::starts_with(".pred")) @@ -238,19 +237,19 @@ arx_fcast_epi_workflow <- function( #' @param n_training Integer. An upper limit for the number of rows per #' key that are used for training #' (in the time unit of the `epi_df`). -#' @param forecast_date Date. The date from which the forecast is occurring. +#' @param reference_date Date. The date from which the forecast is occurring. #' The default `NULL` will determine this automatically from either #' 1. the maximum time value for which there's data if there is no latency #' adjustment (the default case), or #' 2. the `as_of` date of `epi_data` if `adjust_latency` is #' non-`NULL`. #' @param target_date Date. The date that is being forecast. The default `NULL` -#' will determine this automatically as `forecast_date + ahead`. +#' will determine this automatically as `reference_date + ahead`. #' @param adjust_latency Character. One of the `method`s of #' [step_adjust_latency()], or `"none"` (in which case there is no adjustment). -#' If the `forecast_date` is after the last day of data, this determines how +#' If the `reference_date` is after the last day of data, this determines how #' to shift the model to account for this difference. The options are: -#' - `"none"` the default, assumes the `forecast_date` is the last day of data +#' - `"none"` the default, assumes the `reference_date` is the last day of data #' - `"extend_ahead"`: increase the `ahead` by the latency so it's relative to #' the last day of data. For example, if the last day of data was 3 days ago, #' the ahead becomes `ahead+3`. @@ -280,6 +279,7 @@ arx_fcast_epi_workflow <- function( #' column names on which to group the data and check threshold within each #' group. Useful if training per group (for example, per geo_value). #' @param ... 
Space to handle future expansions (unused). +#' @inheritParams get_predict_data #' #' #' @return A list containing updated parameter choices with class `arx_flist`. @@ -294,7 +294,7 @@ arx_args_list <- function( lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, - forecast_date = NULL, + reference_date = NULL, target_date = NULL, adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"), warn_latency = TRUE, @@ -304,6 +304,7 @@ arx_args_list <- function( quantile_by_key = character(0L), check_enough_data_n = NULL, check_enough_data_epi_keys = NULL, + predict_interval = NULL, ...) { # error checking if lags is a list rlang::check_dots_empty() @@ -313,8 +314,8 @@ arx_args_list <- function( adjust_latency <- rlang::arg_match(adjust_latency) arg_is_scalar(ahead, n_training, symmetrize, nonneg, adjust_latency, warn_latency) arg_is_chr(quantile_by_key, allow_empty = TRUE) - arg_is_scalar(forecast_date, target_date, allow_null = TRUE) - arg_is_date(forecast_date, target_date, allow_null = TRUE) + arg_is_scalar(reference_date, target_date, allow_null = TRUE) + arg_is_date(reference_date, target_date, allow_null = TRUE) arg_is_nonneg_int(ahead, lags) arg_is_lgl(symmetrize, nonneg) arg_is_probabilities(quantile_levels, allow_null = TRUE) @@ -323,9 +324,9 @@ arx_args_list <- function( arg_is_pos(check_enough_data_n, allow_null = TRUE) arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) - if (!is.null(forecast_date) && !is.null(target_date)) { - if (forecast_date + ahead != target_date) { - cli_abort("`forecast_date` {.val {forecast_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.", + if (!is.null(reference_date) && !is.null(target_date)) { + if (reference_date + ahead != target_date) { + cli_abort("`reference_date` {.val {reference_date}} + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.", class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date" ) } @@ -338,8 +339,9 @@ arx_args_list <- 
function( ahead, n_training, quantile_levels, - forecast_date, + reference_date, target_date, + predict_interval, adjust_latency, warn_latency, symmetrize, diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index a97eece8..3c445d05 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -78,7 +78,7 @@ cdc_baseline_forecaster <- function( # target_date <- args_list$target_date %||% (forecast_date + args_list$ahead) - latest <- get_test_data(epi_recipe(epi_data), epi_data) + latest <- get_predict_data(epi_recipe(epi_data), epi_data) f <- frosting() %>% layer_predict() %>% diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 81b443e7..e2feb947 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -132,6 +132,9 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' @param new_data A data frame containing the new predictors to preprocess #' and predict on #' +#' @param reference_dates A vector matching the type of `time_value` in +#' `new_data` giving the dates of the predictions to keep. Defaults to the `reference_date` of the `object`'s recipe. +#' #' @inheritParams parsnip::predict.model_fit #' #' @return @@ -155,7 +158,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' #' preds <- predict(wf, latest) #' preds -predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) { +predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), reference_dates = NULL, ...) { if (!workflows::is_trained_workflow(object)) { cli_abort(c( "Can't predict on an untrained epi_workflow.", @@ -170,7 +173,19 @@ predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), . components$keys <- grab_forged_keys(components$forged, object, new_data) components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) 
- components$predictions + reference_dates <- reference_dates %||% extract_recipe(object)$reference_date + #browser() + predictions <- components$predictions %>% filter(time_value %in% reference_dates) + predictions + if (nrow(predictions) == 0) { + last_pred_date <- components$predictions %>% pull(time_value) %>% max() + last_data_date <- new_data %>% pull(time_value) %>% max() + cli_warn( + "no predictions on the reference date(s) {reference_dates}. The last prediction was on {last_pred_date}. The most recent prediction data is on {last_data_date}", + class = "epipredict__predict_epi_workflow__no_predictions" + ) + } + predictions } @@ -238,14 +253,12 @@ print.epi_workflow <- function(x, ...) { #' example, suppose n_recent = 3, then if the 3 most recent observations in any #' geo_value are all NA’s, we won’t be able to fill anything, and an error #' message will be thrown. (See details.) -#' @param forecast_date By default, this is set to the maximum time_value in x. -#' But if there is data latency such that recent NA's should be filled, this may -#' be after the last available time_value. +#' @inheritParams get_predict_data #' #' @return A forecast tibble. 
#' #' @export -forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date = NULL) { +forecast.epi_workflow <- function(object, ..., n_recent = NULL, reference_dates = NULL, predict_interval = NULL) { rlang::check_dots_empty() if (!object$trained) { @@ -255,6 +268,7 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date = )) } + #browser() frosting_fd <- NULL if (has_postprocessor(object) && detect_layer(object, "layer_add_forecast_date")) { frosting_fd <- extract_argument(object, "layer_add_forecast_date", "forecast_date") @@ -266,10 +280,12 @@ forecast.epi_workflow <- function(object, ..., n_recent = NULL, forecast_date = } } - test_data <- get_test_data( + predict_data <- get_predict_data( hardhat::extract_preprocessor(object), - object$original_data + object$original_data, + reference_date = reference_dates, + predict_interval = predict_interval ) - predict(object, new_data = test_data) + predict(object, new_data = predict_data, reference_dates = reference_dates) } diff --git a/R/get_predict_data.R b/R/get_predict_data.R new file mode 100644 index 00000000..466f41de --- /dev/null +++ b/R/get_predict_data.R @@ -0,0 +1,64 @@ +#' Get recent data for prediction within a window before the reference date +#' +#' `get_predict_data()` takes an [epi_df][epiprocess::as_epi_df] `x` and drops +#' the rows whose `time_value` is `predict_interval` or more before the +#' `reference_date`. The result keeps the columns `geo_value`, `time_value` +#' and other variables in the original dataset, +#' which will be used to create features necessary to produce forecasts. +#' +#' The amount of recent data kept is not derived from the lags or horizons +#' requested by the recipe; it is set only by `predict_interval`. The columns +#' of the recipe's template must all be present in `x`, and an error is +#' raised if no rows remain after filtering. +#' +#' @param recipe A recipe object. +#' @param x An epi_df. The typical usage is to +#' pass the same data as that used for fitting the recipe. +#' @param predict_interval A time interval or integer. 
The length of time before +#' the `reference_date` to consider for the forecast. The default is 1 year, +#' which you will likely only need to make longer if you are doing long +#' forecast horizons, or shorter if you are forecasting using an expensive +#' model. +#' @param reference_date Defaults to the `reference_date` stored in `recipe`. +#' If there is data latency such that recent NA's should be filled, this may +#' be after the last available time_value. +#' +#' @return An object of the same type as `x` with columns `geo_value`, +#' `time_value`, any additional keys, as well as other variables in the original +#' dataset. +#' @examples +#' # create recipe +#' rec <- epi_recipe(covid_case_death_rates) %>% +#' step_epi_ahead(death_rate, ahead = 7) %>% +#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_lag(case_rate, lag = c(0, 7, 14)) +#' get_predict_data(recipe = rec, x = covid_case_death_rates) +#' @importFrom rlang %@% +#' @importFrom stats na.omit +#' @export +get_predict_data <- function(recipe, + x, + predict_interval = NULL, + reference_date = NULL) { + if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.") + check <- hardhat::check_column_names(x, colnames(recipe$template)) + if (!check$ok) { + cli_abort(c( + "Some variables used for training are not available in {.arg x}.", + i = "The following required columns are missing: {check$missing_names}" + )) + } + reference_date <- reference_date %||% recipe$reference_date + predict_interval <- predict_interval %||% as.difftime(365, units = "days") + trimmed_x <- x %>% + filter((reference_date - time_value) < predict_interval) + + if (nrow(trimmed_x) == 0) { + cli_abort( + "predict data is filtered to no rows; check your `predict_interval = {predict_interval}`, `reference_date= {reference_date}` and latest data {max(x$time_value)}", + class = "epipredict__get_predict_data__no_predict_data" + ) + } + + trimmed_x +} diff --git a/R/get_test_data.R b/R/get_test_data.R deleted file mode 100644 
index 442272a2..00000000 --- a/R/get_test_data.R +++ /dev/null @@ -1,76 +0,0 @@ -#' Get test data for prediction based on longest lag period -#' -#' Based on the longest lag period in the recipe, -#' `get_test_data()` creates an [epi_df][epiprocess::as_epi_df] -#' with columns `geo_value`, `time_value` -#' and other variables in the original dataset, -#' which will be used to create features necessary to produce forecasts. -#' -#' The minimum required (recent) data to produce a forecast is equal to -#' the maximum lag requested (on any predictor) plus the longest horizon -#' used if growth rate calculations are requested by the recipe. This is -#' calculated internally. -#' -#' @param recipe A recipe object. -#' @param x An epi_df. The typical usage is to -#' pass the same data as that used for fitting the recipe. -#' -#' @return An object of the same type as `x` with columns `geo_value`, `time_value`, any additional -#' keys, as well other variables in the original dataset. -#' @examples -#' # create recipe -#' rec <- epi_recipe(covid_case_death_rates) %>% -#' step_epi_ahead(death_rate, ahead = 7) %>% -#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% -#' step_epi_lag(case_rate, lag = c(0, 7, 14)) -#' get_test_data(recipe = rec, x = covid_case_death_rates) -#' @importFrom rlang %@% -#' @importFrom stats na.omit -#' @export -get_test_data <- function(recipe, x) { - if (!is_epi_df(x)) cli_abort("`x` must be an `epi_df`.") - - check <- hardhat::check_column_names(x, colnames(recipe$template)) - if (!check$ok) { - cli_abort(c( - "Some variables used for training are not available in {.arg x}.", - i = "The following required columns are missing: {check$missing_names}" - )) - } - - min_lags <- min(map_dbl(recipe$steps, ~ min(.x$lag %||% Inf)), Inf) - max_lags <- max(map_dbl(recipe$steps, ~ max(.x$lag %||% 0)), 0) - max_horizon <- max(map_dbl(recipe$steps, ~ max(.x$horizon %||% 0)), 0) - max_slide <- max(map_dbl(recipe$steps, ~ max(.x$before %||% 0)), 0) - min_required 
<- max_lags + max_horizon + max_slide - keep <- max_lags + max_horizon - - # CHECK: Error out if insufficient training data - # Probably needs a fix based on the time_type of the epi_df - avail_recent <- diff(range(x$time_value)) - if (avail_recent < keep) { - cli_abort(c( - "You supplied insufficient recent data for this recipe. ", - "!" = "You need at least {min_required} days of data,", - "!" = "but `x` contains only {avail_recent}." - )) - } - max_time_value <- x %>% - na.omit() %>% - pull(time_value) %>% - max() - x <- arrange(x, time_value) - groups <- epi_keys_only(recipe) - - # If we skip NA completion, we remove undesirably early time values - # Happens globally, over all groups - x <- filter(x, max_time_value - time_value <= keep) - - # If all(lags > 0), then we get rid of recent data - if (min_lags > 0 && min_lags < Inf) { - x <- filter(x, max_time_value - time_value >= min_lags) - } - - filter(x, max_time_value - time_value <= keep) %>% - epiprocess::ungroup() -} diff --git a/R/tidy.R b/R/tidy.R index 3969f9dd..f1f7d347 100644 --- a/R/tidy.R +++ b/R/tidy.R @@ -35,7 +35,7 @@ #' step_epi_naomit() #' #' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) -#' latest <- get_test_data(recipe = r, x = jhu) +#' latest <- get_predict_data(recipe = r, x = jhu) #' #' f <- frosting() %>% #' layer_predict() %>% diff --git a/_pkgdown.yml b/_pkgdown.yml index 814bf6aa..b053af6d 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -84,7 +84,7 @@ reference: contents: - frosting - ends_with("_frosting") - - get_test_data + - get_predict_data - tidy.frosting - title: Frosting layers diff --git a/man/arx_args_list.Rd b/man/arx_args_list.Rd index 650c4a61..52550681 100644 --- a/man/arx_args_list.Rd +++ b/man/arx_args_list.Rd @@ -8,7 +8,7 @@ arx_args_list( lags = c(0L, 7L, 14L), ahead = 7L, n_training = Inf, - forecast_date = NULL, + reference_date = NULL, target_date = NULL, adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"), warn_latency = TRUE, @@ -18,6 
+18,7 @@ arx_args_list( quantile_by_key = character(0L), check_enough_data_n = NULL, check_enough_data_epi_keys = NULL, + predict_interval = NULL, ... ) } @@ -33,7 +34,7 @@ date for which forecasts should be produced.} key that are used for training (in the time unit of the \code{epi_df}).} -\item{forecast_date}{Date. The date from which the forecast is occurring. +\item{reference_date}{Date. The date from which the forecast is occurring. The default \code{NULL} will determine this automatically from either \enumerate{ \item the maximum time value for which there's data if there is no latency @@ -88,6 +89,12 @@ epi_key that are required for training. If \code{NULL}, this check is ignored.} column names on which to group the data and check threshold within each group. Useful if training per group (for example, per geo_value).} +\item{predict_interval}{A time interval or integer. The length of time before +the \code{forecast_date} to consider for the forecast. The default is 1 year, +which you will likely only need to make longer if you are doing long +forecast horizons, or shorter if you are forecasting using an expensive +model.} + \item{...}{Space to handle future expansions (unused).} } \value{ diff --git a/man/arx_class_args_list.Rd b/man/arx_class_args_list.Rd index 40bb48ca..29950645 100644 --- a/man/arx_class_args_list.Rd +++ b/man/arx_class_args_list.Rd @@ -34,15 +34,6 @@ date for which forecasts should be produced.} key that are used for training (in the time unit of the \code{epi_df}).} -\item{forecast_date}{Date. The date from which the forecast is occurring. -The default \code{NULL} will determine this automatically from either -\enumerate{ -\item the maximum time value for which there's data if there is no latency -adjustment (the default case), or -\item the \code{as_of} date of \code{epi_data} if \code{adjust_latency} is -non-\code{NULL}. -}} - \item{target_date}{Date. The date that is being forecast. 
The default \code{NULL} will determine this automatically as \code{forecast_date + ahead}.} diff --git a/man/cdc_baseline_args_list.Rd b/man/cdc_baseline_args_list.Rd index 4a8c1311..0aa67d9d 100644 --- a/man/cdc_baseline_args_list.Rd +++ b/man/cdc_baseline_args_list.Rd @@ -33,15 +33,6 @@ set of prediction horizons for \code{\link[=layer_cdc_flatline_quantiles]{layer_ key that are used for training (in the time unit of the \code{epi_df}).} -\item{forecast_date}{Date. The date from which the forecast is occurring. -The default \code{NULL} will determine this automatically from either -\enumerate{ -\item the maximum time value for which there's data if there is no latency -adjustment (the default case), or -\item the \code{as_of} date of \code{epi_data} if \code{adjust_latency} is -non-\code{NULL}. -}} - \item{quantile_levels}{Vector or \code{NULL}. A vector of probabilities to produce prediction intervals. These are created by computing the quantiles of training residuals. A \code{NULL} value will result in point forecasts only.} diff --git a/man/climate_args_list.Rd b/man/climate_args_list.Rd index 3a889e5c..00dbf3b5 100644 --- a/man/climate_args_list.Rd +++ b/man/climate_args_list.Rd @@ -18,15 +18,6 @@ climate_args_list( ) } \arguments{ -\item{forecast_date}{Date. The date from which the forecast is occurring. -The default \code{NULL} will determine this automatically from either -\enumerate{ -\item the maximum time value for which there's data if there is no latency -adjustment (the default case), or -\item the \code{as_of} date of \code{epi_data} if \code{adjust_latency} is -non-\code{NULL}. 
-}} - \item{forecast_horizon}{Vector of integers giving the number of time steps, in units of the \code{time_type}, from the \code{reference_date} for which predictions should be produced.} diff --git a/man/flatline_args_list.Rd b/man/flatline_args_list.Rd index 626bcb6f..836da4c4 100644 --- a/man/flatline_args_list.Rd +++ b/man/flatline_args_list.Rd @@ -28,15 +28,6 @@ So for example, \code{ahead = 7} will create residuals by comparing values key that are used for training (in the time unit of the \code{epi_df}).} -\item{forecast_date}{Date. The date from which the forecast is occurring. -The default \code{NULL} will determine this automatically from either -\enumerate{ -\item the maximum time value for which there's data if there is no latency -adjustment (the default case), or -\item the \code{as_of} date of \code{epi_data} if \code{adjust_latency} is -non-\code{NULL}. -}} - \item{target_date}{Date. The date that is being forecast. The default \code{NULL} will determine this automatically as \code{forecast_date + ahead}.} diff --git a/man/forecast.epi_workflow.Rd b/man/forecast.epi_workflow.Rd index 22f8cf4b..1ba51003 100644 --- a/man/forecast.epi_workflow.Rd +++ b/man/forecast.epi_workflow.Rd @@ -4,7 +4,13 @@ \alias{forecast.epi_workflow} \title{Produce a forecast from an epi workflow} \usage{ -\method{forecast}{epi_workflow}(object, ..., n_recent = NULL, forecast_date = NULL) +\method{forecast}{epi_workflow}( + object, + ..., + n_recent = NULL, + reference_date = NULL, + predict_interval = NULL +) } \arguments{ \item{object}{An epi workflow.} @@ -18,9 +24,15 @@ example, suppose n_recent = 3, then if the 3 most recent observations in any geo_value are all NA’s, we won’t be able to fill anything, and an error message will be thrown. (See details.)} -\item{forecast_date}{By default, this is set to the maximum time_value in x. +\item{reference_date}{By default, this is set to the maximum time_value in x. 
But if there is data latency such that recent NA's should be filled, this may be after the last available time_value.} + +\item{predict_interval}{A time interval or integer. The length of time before +the \code{forecast_date} to consider for the forecast. The default is 1 year, +which you will likely only need to make longer if you are doing long +forecast horizons, or shorter if you are forecasting using an expensive +model.} } \value{ A forecast tibble. diff --git a/man/get_test_data.Rd b/man/get_predict_data.Rd similarity index 58% rename from man/get_test_data.Rd rename to man/get_predict_data.Rd index 16359b9c..bb447d3e 100644 --- a/man/get_test_data.Rd +++ b/man/get_predict_data.Rd @@ -1,16 +1,26 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_test_data.R -\name{get_test_data} -\alias{get_test_data} +% Please edit documentation in R/get_predict_data.R +\name{get_predict_data} +\alias{get_predict_data} \title{Get test data for prediction based on longest lag period} \usage{ -get_test_data(recipe, x) +get_predict_data(recipe, x, predict_interval = NULL, reference_date = NULL) } \arguments{ \item{recipe}{A recipe object.} \item{x}{An epi_df. The typical usage is to pass the same data as that used for fitting the recipe.} + +\item{predict_interval}{A time interval or integer. The length of time before +the \code{forecast_date} to consider for the forecast. The default is 1 year, +which you will likely only need to make longer if you are doing long +forecast horizons, or shorter if you are forecasting using an expensive +model.} + +\item{reference_date}{By default, this is set to the maximum time_value in x. +But if there is data latency such that recent NA's should be filled, this may +be after the last available time_value.} } \value{ An object of the same type as \code{x} with columns \code{geo_value}, \code{time_value}, any additional @@ -18,7 +28,7 @@ keys, as well other variables in the original dataset. 
} \description{ Based on the longest lag period in the recipe, -\code{get_test_data()} creates an \link[epiprocess:epi_df]{epi_df} +\code{get_predict_data()} creates an \link[epiprocess:epi_df]{epi_df} with columns \code{geo_value}, \code{time_value} and other variables in the original dataset, which will be used to create features necessary to produce forecasts. @@ -35,5 +45,5 @@ rec <- epi_recipe(covid_case_death_rates) \%>\% step_epi_ahead(death_rate, ahead = 7) \%>\% step_epi_lag(death_rate, lag = c(0, 7, 14)) \%>\% step_epi_lag(case_rate, lag = c(0, 7, 14)) -get_test_data(recipe = rec, x = covid_case_death_rates) +get_predict_data(recipe = rec, x = covid_case_death_rates) } diff --git a/man/grf_quantiles.Rd b/man/grf_quantiles.Rd index 2e4b8bcb..ce40b684 100644 --- a/man/grf_quantiles.Rd +++ b/man/grf_quantiles.Rd @@ -52,8 +52,8 @@ details, see \href{https://grf-labs.github.io/grf/articles/categorical_inputs.ht #> Model fit template: #> grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1), #> x), num.trees = integer(1), min.node.size = min_rows(~integer(1), -#> x), quantiles = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95), -#> num.threads = 1L, seed = stats::runif(1, 0, .Machine$integer.max)) +#> x), quantiles = c(0.1, 0.5, 0.9), num.threads = 1L, seed = stats::runif(1, +#> 0, .Machine$integer.max)) }\if{html}{\out{}} } diff --git a/man/predict-epi_workflow.Rd b/man/predict-epi_workflow.Rd index 0b605d55..1f93ebee 100644 --- a/man/predict-epi_workflow.Rd +++ b/man/predict-epi_workflow.Rd @@ -5,7 +5,14 @@ \alias{predict.epi_workflow} \title{Predict from an epi_workflow} \usage{ -\method{predict}{epi_workflow}(object, new_data, type = NULL, opts = list(), ...) +\method{predict}{epi_workflow}( + object, + new_data, + type = NULL, + opts = list(), + reference_dates = NULL, + ... +) } \arguments{ \item{object}{An epi_workflow that has been fit by @@ -24,6 +31,9 @@ predict function that will be used when \code{type = "raw"}. 
The list should not include options for the model object or the new data being predicted.} +\item{reference_dates}{A vector matching the type of \code{time_value} in +\code{new_data} giving the dates of the predictions to keep. Defaults to the \code{reference_date} of the \code{object}'s recipe.} + \item{...}{Additional \code{parsnip}-related options, depending on the value of \code{type}. Arguments to the underlying model's prediction function cannot be passed here (use the \code{opts} argument instead). diff --git a/man/step_adjust_latency.Rd b/man/step_adjust_latency.Rd index 9e1bafbd..1a9bc7a3 100644 --- a/man/step_adjust_latency.Rd +++ b/man/step_adjust_latency.Rd @@ -150,8 +150,7 @@ toy_recipe \%>\% #> 4 ca 2015-01-14 103 10 #> 5 ma 2015-01-11 20 6 #> 6 ma 2015-01-12 23 6 -#> 7 ma 2015-01-13 25 6 -#> 8 ma 2015-01-14 25 6 +#> # i 2 more rows }\if{html}{\out{}} } @@ -178,19 +177,15 @@ toy_recipe \%>\% #> * as_of = 2015-01-14 #> #> # A tibble: 21 x 7 -#> geo_value time_value a b lag_3_a lag_4_b ahead_1_a -#> -#> 1 ca 2015-01-10 NA NA NA NA 100 -#> 2 ca 2015-01-11 100 5 NA NA 103 -#> 3 ca 2015-01-12 103 10 NA NA NA -#> 4 ca 2015-01-13 NA NA NA NA NA -#> 5 ca 2015-01-14 NA NA 100 NA NA -#> 6 ca 2015-01-15 NA NA 103 5 NA -#> 7 ca 2015-01-16 NA NA NA 10 NA -#> 8 ca 2015-01-17 NA NA NA NA NA -#> 9 ca 2015-01-18 NA NA NA NA NA -#> 10 ca 2015-01-19 NA NA NA NA NA -#> # i 11 more rows +#> geo_value time_value a b lag_3_a lag_4_b ahead_1_a +#> +#> 1 ca 2015-01-10 NA NA NA NA 100 +#> 2 ca 2015-01-11 100 5 NA NA 103 +#> 3 ca 2015-01-12 103 10 NA NA NA +#> 4 ca 2015-01-13 NA NA NA NA NA +#> 5 ca 2015-01-14 NA NA 100 NA NA +#> 6 ca 2015-01-15 NA NA 103 5 NA +#> # i 15 more rows }\if{html}{\out{}} The maximum latency in column \code{a} is 2 days, so the lag is increased to 3, @@ -226,18 +221,15 @@ toy_recipe \%>\% #> * as_of = 2015-01-14 #> #> # A tibble: 10 x 6 -#> geo_value time_value a b lag_0_a ahead_3_a -#> -#> 1 ca 2015-01-08 NA NA NA 100 -#> 2 ca 2015-01-09 NA NA NA 
103 -#> 3 ca 2015-01-11 100 5 100 NA -#> 4 ca 2015-01-12 103 10 103 NA -#> 5 ma 2015-01-08 NA NA NA 20 -#> 6 ma 2015-01-09 NA NA NA 23 -#> 7 ma 2015-01-10 NA NA NA 25 -#> 8 ma 2015-01-11 20 6 20 NA -#> 9 ma 2015-01-12 23 NA 23 NA -#> 10 ma 2015-01-13 25 NA 25 NA +#> geo_value time_value a b lag_0_a ahead_3_a +#> +#> 1 ca 2015-01-08 NA NA NA 100 +#> 2 ca 2015-01-09 NA NA NA 103 +#> 3 ca 2015-01-11 100 5 100 NA +#> 4 ca 2015-01-12 103 10 103 NA +#> 5 ma 2015-01-08 NA NA NA 20 +#> 6 ma 2015-01-09 NA NA NA 23 +#> # i 4 more rows }\if{html}{\out{}} Even though we're doing a 1 day ahead forecast, because our worst latency diff --git a/man/tidy.frosting.Rd b/man/tidy.frosting.Rd index 3f9b0e37..7e5834be 100644 --- a/man/tidy.frosting.Rd +++ b/man/tidy.frosting.Rd @@ -46,7 +46,7 @@ r <- epi_recipe(jhu) \%>\% step_epi_naomit() wf <- epi_workflow(r, parsnip::linear_reg()) \%>\% fit(jhu) -latest <- get_test_data(recipe = r, x = jhu) +latest <- get_predict_data(recipe = r, x = jhu) f <- frosting() \%>\% layer_predict() \%>\% diff --git a/tests/testthat/_snaps/check_enough_data.md b/tests/testthat/_snaps/check_enough_data.md index 4a6ff336..2d88c77c 100644 --- a/tests/testthat/_snaps/check_enough_data.md +++ b/tests/testthat/_snaps/check_enough_data.md @@ -37,7 +37,7 @@ # check_enough_data only checks train data when skip = FALSE Code - forecaster %>% predict(new_data = toy_test_data %>% filter(time_value > + forecaster %>% predict(new_data = toy_predict_data %>% filter(time_value > "2020-01-08")) Condition Error in `check_enough_data_core()`: diff --git a/tests/testthat/_snaps/get_predict_data.md b/tests/testthat/_snaps/get_predict_data.md new file mode 100644 index 00000000..f7d96e71 --- /dev/null +++ b/tests/testthat/_snaps/get_predict_data.md @@ -0,0 +1,16 @@ +# expect insufficient training data error when the forecast date is unreasonable + + Code + get_predict_data(recipe = r, x = covid_case_death_rates) + Condition + Error in `get_predict_data()`: + ! 
predict data is filtered to no rows; check your `predict_interval = 365` and `reference_date= 2023-03-10` + +# expect error that geo_value or time_value does not exist + + Code + get_predict_data(recipe = r, x = wrong_epi_df) + Condition + Error in `get_predict_data()`: + ! `x` must be an `epi_df`. + diff --git a/tests/testthat/_snaps/get_test_data.md b/tests/testthat/_snaps/get_test_data.md deleted file mode 100644 index 22d0c942..00000000 --- a/tests/testthat/_snaps/get_test_data.md +++ /dev/null @@ -1,66 +0,0 @@ -# expect insufficient training data error - - Code - get_test_data(recipe = r, x = covid_case_death_rates) - Condition - Error in `get_test_data()`: - ! You supplied insufficient recent data for this recipe. - ! You need at least 367 days of data, - ! but `x` contains only 365. - -# expect error that geo_value or time_value does not exist - - Code - get_test_data(recipe = r, x = wrong_epi_df) - Condition - Error in `get_test_data()`: - ! `x` must be an `epi_df`. - -# NA fill behaves as desired - - Code - get_test_data(r, df, "A") - Condition - Error in `get_test_data()`: - ! `fill_locf` must be of type . - ---- - - Code - get_test_data(r, df, TRUE, -3) - Condition - Error in `get_test_data()`: - ! `n_recent` must be a positive integer. - ---- - - Code - get_test_data(r, df2, TRUE) - Condition - Error in `if (recipes::is_trained(recipe)) ...`: - ! argument is of length zero - -# forecast date behaves - - Code - get_test_data(r, df, TRUE, forecast_date = 9) - Condition - Error in `get_test_data()`: - ! `forecast_date` must be the same class as `x$time_value`. - ---- - - Code - get_test_data(r, df, TRUE, forecast_date = 9L) - Condition - Error in `get_test_data()`: - ! `forecast_date` must be no earlier than `max(x$time_value)` - ---- - - Code - get_test_data(r, df, forecast_date = 9L) - Condition - Error in `get_test_data()`: - ! 
`forecast_date` must be no earlier than `max(x$time_value)` - diff --git a/tests/testthat/test-arx_forecaster.R b/tests/testthat/test-arx_forecaster.R index d13e6d2e..5332bc05 100644 --- a/tests/testthat/test-arx_forecaster.R +++ b/tests/testthat/test-arx_forecaster.R @@ -2,24 +2,26 @@ train_data <- epidatasets::cases_deaths_subset test_that("arx_forecaster warns if forecast date beyond the implicit one", { bad_date <- max(train_data$time_value) + 300 expect_warning( - arx1 <- arx_forecaster( - train_data, - "death_rate_7d_av", - c("death_rate_7d_av", "case_rate_7d_av"), - args_list = (arx_args_list(forecast_date = bad_date)) + expect_warning( + arx1 <- arx_forecaster( + train_data, + "death_rate_7d_av", + c("death_rate_7d_av", "case_rate_7d_av"), + args_list = (arx_args_list(reference_date = bad_date)) + ), + class = "epipredict__arx_forecaster__forecast_date_defaulting" ), - class = "epipredict__arx_forecaster__forecast_date_defaulting" - ) + class = "epipredict__predict_epi_workflow__no_predictions") }) -test_that("arx_forecaster errors if forecast date, target date, and ahead are inconsistent", { +test_that("arx_forecaster errors if reference date, target date, and ahead are inconsistent", { max_date <- max(train_data$time_value) expect_error( arx1 <- arx_forecaster( train_data, "death_rate_7d_av", c("death_rate_7d_av", "case_rate_7d_av"), - args_list = (arx_args_list(ahead = 5, target_date = max_date, forecast_date = max_date)) + args_list = (arx_args_list(ahead = 5, target_date = max_date, reference_date = max_date)) ), class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date" ) @@ -36,10 +38,9 @@ test_that("warns if there's not enough data to predict", { # and actually, pretend we're around mid-October 2022: filter(time_value <= as.Date("2022-10-12")) %>% as_epi_df(as_of = as.Date("2022-10-12")) - edf %>% filter(time_value > "2022-08-01") expect_error( - edf %>% arx_forecaster("value"), - class = "epipredict__not_enough_data" + edf %>% 
arx_forecaster("value", args_list = arx_args_list(predict_interval = as.difftime(0, units = "days"))), + class = "epipredict__get_predict_data__no_predict_data" ) }) diff --git a/tests/testthat/test-check_enough_data.R b/tests/testthat/test-check_enough_data.R index 3ca388af..42bf58c7 100644 --- a/tests/testthat/test-check_enough_data.R +++ b/tests/testthat/test-check_enough_data.R @@ -84,7 +84,7 @@ test_that("check_enough_data outputs the correct recipe values", { test_that("check_enough_data only checks train data when skip = FALSE", { # Check that the train data has enough data, the test data does not, but # the check passes anyway (because it should be applied to training data) - toy_test_data <- toy_epi_df %>% + toy_predict_data <- toy_epi_df %>% group_by(geo_value) %>% slice(3:10) %>% epiprocess::as_epi_df() @@ -92,7 +92,7 @@ test_that("check_enough_data only checks train data when skip = FALSE", { epi_recipe(toy_epi_df) %>% check_enough_data(x, y, min_observations = n - 2, epi_keys = "geo_value") %>% prep(toy_epi_df) %>% - bake(new_data = toy_test_data) + bake(new_data = toy_predict_data) ) # Making sure `skip = TRUE` is working correctly in `predict` expect_no_error( @@ -101,7 +101,7 @@ test_that("check_enough_data only checks train data when skip = FALSE", { check_enough_data(x, min_observations = n - 2, epi_keys = "geo_value") %>% epi_workflow(linear_reg()) %>% fit(toy_epi_df) %>% - predict(new_data = toy_test_data %>% filter(time_value > "2020-01-08")) + predict(new_data = toy_predict_data %>% filter(time_value > "2020-01-08")) ) # making sure it works for skip = FALSE, where there's enough data to train # but not enough to predict @@ -115,7 +115,7 @@ test_that("check_enough_data only checks train data when skip = FALSE", { expect_snapshot( error = TRUE, forecaster %>% - predict(new_data = toy_test_data %>% filter(time_value > "2020-01-08")) + predict(new_data = toy_predict_data %>% filter(time_value > "2020-01-08")) ) }) diff --git 
a/tests/testthat/test-epi_workflow.R b/tests/testthat/test-epi_workflow.R index cce68a80..5da62064 100644 --- a/tests/testthat/test-epi_workflow.R +++ b/tests/testthat/test-epi_workflow.R @@ -66,14 +66,14 @@ test_that("model can be added/updated/removed from epi_workflow", { test_that("forecast method works", { jhu <- covid_case_death_rates %>% filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) - r <- epi_recipe(jhu) %>% + r <- epi_recipe(jhu, reference_date = max(jhu$time_value)) %>% step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_naomit() wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) expect_equal( forecast(wf), - predict(wf, new_data = get_test_data( + predict(wf, new_data = get_predict_data( hardhat::extract_preprocessor(wf), jhu )) @@ -81,7 +81,7 @@ test_that("forecast method works", { expect_equal( forecast(wf), - predict(wf, new_data = get_test_data( + predict(wf, new_data = get_predict_data( hardhat::extract_preprocessor(wf), jhu )) diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index cd153b20..090f528d 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -99,7 +99,7 @@ test_that("parsnip settings can be passed through predict.epi_workflow", { wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) - latest <- get_test_data(r, jhu) + latest <- get_predict_data(r, jhu) f1 <- frosting() %>% layer_predict() f2 <- frosting() %>% layer_predict(type = "pred_int") diff --git a/tests/testthat/test-get_predict_data.R b/tests/testthat/test-get_predict_data.R new file mode 100644 index 00000000..c86c77fc --- /dev/null +++ b/tests/testthat/test-get_predict_data.R @@ -0,0 +1,56 @@ +suppressPackageStartupMessages(library(dplyr)) +forecast_date <- max(covid_case_death_rates$time_value) +test_that("return expected number of rows for various `predict_intervals`", { + r <- epi_recipe(covid_case_death_rates, reference_date = 
forecast_date) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14, 21, 28)) %>% + step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% + step_naomit(all_predictors()) %>% + step_naomit(all_outcomes(), skip = TRUE) + + predict_data <- get_predict_data(recipe = r, x = covid_case_death_rates) + + expect_equal( + nrow(predict_data), + dplyr::n_distinct(covid_case_death_rates$geo_value) * 365 + ) + + predict_data <- get_predict_data(recipe = r, predict_interval = 5, x = covid_case_death_rates) + + expect_equal( + nrow(predict_data), + dplyr::n_distinct(covid_case_death_rates$geo_value) * 5 + ) + + predict_data <- get_predict_data(recipe = r, predict_interval = as.difftime(35, units = "days"), x = covid_case_death_rates) + + expect_equal( + nrow(predict_data), + dplyr::n_distinct(covid_case_death_rates$geo_value) * 35 + ) +}) + + +test_that("expect insufficient training data error when the forecast date is unreasonable", { + r <- epi_recipe(covid_case_death_rates) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_lag(death_rate, lag = c(0, 367)) %>% + step_naomit(all_predictors()) %>% + step_naomit(all_outcomes(), skip = TRUE) + + expect_snapshot(error = TRUE, get_predict_data(recipe = r, x = covid_case_death_rates)) +}) + + +test_that("expect error that geo_value or time_value does not exist", { + r <- epi_recipe(covid_case_death_rates, reference_date = forecast_date) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% + step_naomit(all_predictors()) %>% + step_naomit(all_outcomes(), skip = TRUE) + + wrong_epi_df <- covid_case_death_rates %>% dplyr::select(-geo_value) + + expect_snapshot(error = TRUE, get_predict_data(recipe = r, x = wrong_epi_df)) +}) diff --git a/tests/testthat/test-get_test_data.R b/tests/testthat/test-get_test_data.R deleted file mode 100644 index 7822f543..00000000 --- a/tests/testthat/test-get_test_data.R +++ 
/dev/null @@ -1,161 +0,0 @@ -suppressPackageStartupMessages(library(dplyr)) -test_that("return expected number of rows and returned dataset is ungrouped", { - r <- epi_recipe(covid_case_death_rates) %>% - step_epi_ahead(death_rate, ahead = 7) %>% - step_epi_lag(death_rate, lag = c(0, 7, 14, 21, 28)) %>% - step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% - step_naomit(all_predictors()) %>% - step_naomit(all_outcomes(), skip = TRUE) - - test <- get_test_data(recipe = r, x = covid_case_death_rates) - - expect_equal( - nrow(test), - dplyr::n_distinct(covid_case_death_rates$geo_value) * 29 - ) - - expect_false(dplyr::is.grouped_df(test)) -}) - - -test_that("expect insufficient training data error", { - r <- epi_recipe(covid_case_death_rates) %>% - step_epi_ahead(death_rate, ahead = 7) %>% - step_epi_lag(death_rate, lag = c(0, 367)) %>% - step_naomit(all_predictors()) %>% - step_naomit(all_outcomes(), skip = TRUE) - - expect_snapshot(error = TRUE, get_test_data(recipe = r, x = covid_case_death_rates)) -}) - - -test_that("expect error that geo_value or time_value does not exist", { - r <- epi_recipe(covid_case_death_rates) %>% - step_epi_ahead(death_rate, ahead = 7) %>% - step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% - step_epi_lag(case_rate, lag = c(0, 7, 14)) %>% - step_naomit(all_predictors()) %>% - step_naomit(all_outcomes(), skip = TRUE) - - wrong_epi_df <- covid_case_death_rates %>% dplyr::select(-geo_value) - - expect_snapshot(error = TRUE, get_test_data(recipe = r, x = wrong_epi_df)) -}) - - -test_that("NA fill behaves as desired", { - testthat::skip() - df <- tibble::tibble( - geo_value = rep(c("ca", "ny"), each = 10), - time_value = rep(1:10, times = 2), - x1 = rnorm(20), - x2 = rnorm(20) - ) %>% - epiprocess::as_epi_df() - - r <- epi_recipe(df) %>% - step_epi_ahead(x1, ahead = 3) %>% - step_epi_lag(x1, x2, lag = c(1, 3)) %>% - step_epi_naomit() - - expect_silent(tt <- get_test_data(r, df)) - expect_s3_class(tt, "epi_df") - - expect_snapshot(error = TRUE, 
get_test_data(r, df, "A")) - expect_snapshot(error = TRUE, get_test_data(r, df, TRUE, -3)) - - df2 <- df - df2$x1[df2$geo_value == "ca"] <- NA - - td <- get_test_data(r, df2) - expect_true(any(is.na(td))) - expect_snapshot(error = TRUE, get_test_data(r, df2, TRUE)) - - df1 <- df2 - df1$x1[1:4] <- 1:4 - td1 <- get_test_data(r, df1, TRUE, n_recent = 7) - expect_true(!any(is.na(td1))) - - df2$x1[7:8] <- 1:2 - td2 <- get_test_data(r, df2, TRUE) - expect_true(!any(is.na(td2))) -}) - -test_that("forecast date behaves", { - testthat::skip() - df <- tibble::tibble( - geo_value = rep(c("ca", "ny"), each = 10), - time_value = rep(1:10, times = 2), - x1 = rnorm(20), - x2 = rnorm(20) - ) %>% - epiprocess::as_epi_df() - - r <- epi_recipe(df) %>% - step_epi_ahead(x1, ahead = 3) %>% - step_epi_lag(x1, x2, lag = c(1, 3)) - - expect_snapshot(error = TRUE, get_test_data(r, df, TRUE, forecast_date = 9)) # class error - expect_snapshot(error = TRUE, get_test_data(r, df, TRUE, forecast_date = 9L)) # fd too early - expect_snapshot(error = TRUE, get_test_data(r, df, forecast_date = 9L)) # fd too early - - ndf <- get_test_data(r, df, TRUE, forecast_date = 12L) - expect_equal(max(ndf$time_value), 11L) # max lag was 1 - expect_equal(tail(ndf$x1, 2), tail(ndf$x1, 4)[1:2]) # should have filled - - ndf <- get_test_data(r, df, FALSE, forecast_date = 12L) - expect_equal(max(ndf$time_value), 11L) - expect_equal(tail(ndf$x1, 2), as.double(c(NA, NA))) -}) - -test_that("Omit end rows according to minimum lag when that’s not lag 0", { - # Simple toy ex - - toy_epi_df <- tibble::tibble( - time_value = seq(as.Date("2020-01-01"), - by = 1, - length.out = 10 - ), - geo_value = "ak", - x = 1:10 - ) %>% epiprocess::as_epi_df() - - toy_rec <- epi_recipe(toy_epi_df) %>% - step_epi_lag(x, lag = c(2, 4)) %>% - step_epi_ahead(x, ahead = 3) %>% - step_epi_naomit() - - toy_td <- get_test_data(toy_rec, toy_epi_df) - - toy_td_res <- bake(prep(toy_rec, toy_epi_df), toy_td) - - expect_equal(ncol(toy_td_res), 6L) - 
expect_equal(nrow(toy_td_res), 1L) - expect_equal(toy_td_res$time_value, as.Date("2020-01-10")) - expect_equal(toy_epi_df[toy_epi_df$time_value == as.Date("2020-01-08"), ]$x, toy_td_res$lag_2_x) - expect_equal(toy_epi_df[toy_epi_df$time_value == as.Date("2020-01-06"), ]$x, toy_td_res$lag_4_x) - expect_equal(toy_td_res$x, NA_integer_) - expect_equal(toy_td_res$ahead_3_x, NA_integer_) - - # Ex. using real built-in data - - ca <- covid_case_death_rates %>% - filter(geo_value == "ca") - - rec <- epi_recipe(ca) %>% - step_epi_lag(case_rate, lag = c(2, 4, 6)) %>% - step_epi_ahead(case_rate, ahead = 7) %>% - step_epi_naomit() - - td <- get_test_data(rec, ca) - - td_res <- bake(prep(rec, ca), td) - td_row1to5_res <- bake(prep(rec, ca), td[1:5, ]) - - expect_equal(td_res, td_row1to5_res) - expect_equal(nrow(td_res), 1L) - expect_equal(td_res$time_value, as.Date("2021-12-31")) - expect_equal(ca[ca$time_value == as.Date("2021-12-29"), ]$case_rate, td_res$lag_2_case_rate) - expect_equal(ca[ca$time_value == as.Date("2021-12-27"), ]$case_rate, td_res$lag_4_case_rate) - expect_equal(ca[ca$time_value == as.Date("2021-12-25"), ]$case_rate, td_res$lag_6_case_rate) -}) diff --git a/tests/testthat/test-layer_naomit.R b/tests/testthat/test-layer_naomit.R index 8eb597f4..964c46ec 100644 --- a/tests/testthat/test-layer_naomit.R +++ b/tests/testthat/test-layer_naomit.R @@ -9,7 +9,7 @@ r <- epi_recipe(jhu) %>% wf <- epipredict::epi_workflow(r, parsnip::linear_reg()) %>% parsnip::fit(jhu) -latest <- get_test_data(recipe = r, x = jhu) %>% # 93 x 4 +latest <- get_predict_data(recipe = r, x = jhu) %>% # 93 x 4 dplyr::arrange(geo_value, time_value) latest[1:10, 4] <- NA # 10 rows have NA diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index f2efde3c..5a3023f5 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -235,7 +235,7 @@ test_that("test joining by default columns", { fit(jhu) %>% 
add_frosting(f) - latest <- get_test_data( + latest <- get_predict_data( recipe = r, x = covid_case_death_rates %>% dplyr::filter( @@ -415,7 +415,7 @@ test_that("test joining by default columns with less common keys/classes", { expect_warning( expect_warning( expect_equal( - # get_test_data doesn't work with non-`epi_df`s, so provide test data manually: + # get_predict_data doesn't work with non-`epi_df`s, so provide test data manually: predict(fit(ewf1b2, dat1b2), dat1b2) %>% pivot_quantiles_wider(.pred) %>% as_tibble(), @@ -583,7 +583,7 @@ test_that("test joining by default columns with less common keys/classes", { mutate(y_scaled = c(3e-6, 7e-6)) ) expect_error( - # get_test_data doesn't work with non-`epi_df`s, so provide test data manually: + # get_predict_data doesn't work with non-`epi_df`s, so provide test data manually: predict(fit(ewf4, dat4), dat4) %>% pivot_quantiles_wider(.pred), class = "epipredict__grab_forged_keys__nonunique_key" diff --git a/tests/testthat/test-step_adjust_latency.R b/tests/testthat/test-step_adjust_latency.R index 80e31dc1..edcf6dee 100644 --- a/tests/testthat/test-step_adjust_latency.R +++ b/tests/testthat/test-step_adjust_latency.R @@ -96,7 +96,7 @@ test_that("epi_adjust_latency correctly extends the lags", { "lag_6_case_rate", "lag_10_case_rate" ) ) - latest <- get_test_data(r1, real_x) + latest <- get_predict_data(r1, real_x) pred <- predict(fit1, latest) point_pred <- pred %>% filter(!is.na(.pred)) expect_equal(nrow(point_pred), 1) @@ -106,7 +106,7 @@ test_that("epi_adjust_latency correctly extends the lags", { names(fit1$pre$mold$outcomes), glue::glue("ahead_{ahead}_death_rate") ) - latest <- get_test_data(r1, x) + latest <- get_predict_data(r1, x) pred1 <- predict(fit1, latest) actual_solutions <- pred1 %>% filter(!is.na(.pred)) expect_equal(actual_solutions$time_value, testing_as_of) @@ -146,7 +146,7 @@ test_that("epi_adjust_latency correctly extends the ahead", { "lag_1_case_rate", "lag_5_case_rate" ) ) - latest <- 
get_test_data(r2, real_x) + latest <- get_predict_data(r2, real_x) pred2 <- predict(fit2, latest) point_pred2 <- pred2 %>% filter(!is.na(.pred)) # max time is still the forecast date @@ -262,7 +262,7 @@ test_that("epi_adjust_latency extends multiple aheads", { "lag_1_case_rate", "lag_5_case_rate" ) ) - latest <- get_test_data(r3, real_x) + latest <- get_predict_data(r3, real_x) pred3 <- predict(fit3, latest) point_pred <- pred3 %>% unnest(.pred) %>% @@ -393,7 +393,7 @@ test_that("epi_adjust_latency correctly extends the lags when there are differen "lag_7_case_rate", "lag_11_case_rate" ) ) - latest <- get_test_data(r5, x_lagged) + latest <- get_predict_data(r5, x_lagged) pred <- predict(fit5, latest) point_pred <- pred %>% filter(!is.na(.pred)) expect_equal(nrow(point_pred), 1) @@ -435,7 +435,7 @@ test_that("epi_adjust_latency correctly extends the ahead when there are differe "lag_1_case_rate", "lag_5_case_rate" ) ) - latest <- get_test_data(r5, x_lagged) + latest <- get_predict_data(r5, x_lagged) pred <- predict(fit5, latest) point_pred <- pred %>% filter(!is.na(.pred)) expect_equal(nrow(point_pred), 1) @@ -568,7 +568,7 @@ test_that("locf works as intended", { "lag_1_case_rate", "lag_5_case_rate" ) ) - latest <- get_test_data(r6, real_x) + latest <- get_predict_data(r6, real_x) pred <- predict(fit6, latest) point_pred <- pred %>% filter(!is.na(.pred)) expect_equal(max(point_pred$time_value), as.Date(testing_as_of)) @@ -577,7 +577,7 @@ test_that("locf works as intended", { names(fit6$pre$mold$outcomes), glue::glue("ahead_{ahead}_death_rate") ) - latest <- get_test_data(r6, x) + latest <- get_predict_data(r6, x) pred1 <- predict(fit6, latest) actual_solutions <- pred1 %>% filter(!is.na(.pred)) expect_equal(max(actual_solutions$time_value), testing_as_of) diff --git a/tests/testthat/test-step_climate.R b/tests/testthat/test-step_climate.R index 3dca9ec6..19a8cdb2 100644 --- a/tests/testthat/test-step_climate.R +++ b/tests/testthat/test-step_climate.R @@ -238,7 
+238,7 @@ test_that("leading the climate predictor works as expected", { expect_identical(b$climate_y, expected_climate_pred) # Check if our test data has the right values - td <- get_test_data(r, x) + td <- get_predict_data(r, x) expected_test_x <- td %>% filter(time_value == "2021-12-31") %>% mutate( diff --git a/vignettes/articles/smooth-qr.Rmd b/vignettes/articles/smooth-qr.Rmd index ec07272a..5f9dbf8d 100644 --- a/vignettes/articles/smooth-qr.Rmd +++ b/vignettes/articles/smooth-qr.Rmd @@ -195,7 +195,7 @@ smooth_fc <- function(x, aheads = 1:28, degree = 3L, quantiles = 0.5, fd) { the_fit <- ewf %>% fit(x) - latest <- get_test_data(rec, x) + latest <- get_predict_data(rec, x) preds <- predict(the_fit, new_data = latest) %>% mutate(forecast_date = fd, target_date = fd + ahead) %>% diff --git a/vignettes/epipredict.Rmd b/vignettes/epipredict.Rmd index ce0a7e38..2afacbb4 100644 --- a/vignettes/epipredict.Rmd +++ b/vignettes/epipredict.Rmd @@ -424,7 +424,7 @@ ewf %>% forecast() ``` -The above `get_test_data()` function examines the recipe and ensures that enough +The above `get_predict_data()` function examines the recipe and ensures that enough test data is available to create the necessary lags and produce a prediction for the desired future time point (after the end of the training data). This mimics what would happen if `jhu` contained the most recent available historical data and diff --git a/vignettes/panel-data.Rmd b/vignettes/panel-data.Rmd index e9905789..f2e00970 100644 --- a/vignettes/panel-data.Rmd +++ b/vignettes/panel-data.Rmd @@ -252,7 +252,7 @@ data. For this demo, we will predict the number of graduates using the last 2 years of our dataset. 
```{r linearreg-predict, include=T} -latest <- get_test_data(recipe = r, x = employ_small) +latest <- get_predict_data(recipe = r, x = employ_small) preds <- stats::predict(wf_linreg, latest) %>% filter(!is.na(.pred)) # Display a sample of the prediction values, excluding NAs preds %>% sample_n(5) @@ -384,7 +384,7 @@ and perform hypothesis tests as usual. Let's take a look at the predictions along with their 90% prediction intervals. ```{r} -latest <- get_test_data(recipe = rx, x = employ_small) +latest <- get_predict_data(recipe = rx, x = employ_small) predsx <- predict(wfx_linreg, latest) # Display predictions along with prediction intervals diff --git a/vignettes/preprocessing-and-models.Rmd b/vignettes/preprocessing-and-models.Rmd index 6bff4561..c9c29dbd 100644 --- a/vignettes/preprocessing-and-models.Rmd +++ b/vignettes/preprocessing-and-models.Rmd @@ -179,7 +179,7 @@ modeling and producing the prediction for death count, 7 days after the latest available date in the dataset. ```{r} -latest <- get_test_data(r, counts_subset) +latest <- get_predict_data(r, counts_subset) wf <- epi_workflow(r, parsnip::poisson_reg()) %>% fit(counts_subset)