Skip to content

Commit a8b4264

Browse files
Sync process. WIP. ISSUE: ppc_calibratrion loses the posterior mean.
1 parent 14eb2dc commit a8b4264

File tree

2 files changed

+494
-70
lines changed

2 files changed

+494
-70
lines changed

R/ppc-calibration.R

Lines changed: 237 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,37 @@
1-
# x' PPC calibration
1+
#' PPC calibration
22
#'
3-
#' Assess the calibration of the predictive distributions `yrep` in relation to
4-
#' the data `y'.
3+
#' Assess the calibration of the predictions, or predictive probabilites in relation to
4+
#' binary observations.
55
#' See the **Plot Descriptions** section, below, for details.
66
#'
77
#' @name PPC-calibration
88
#' @family PPCs
99
#'
1010
#' @template args-y-yrep
1111
#' @template args-group
12+
#' @param interval_type For `ppc_calibration()`, `ppc_calibration_grouped()`,
13+
#' 'ppc_loo_calibration()', and ´ppc_loo_calibration_grouped()´, the type of
14+
#' interval to compute. Options are '"consistency"' (default) for credible
15+
#' intervals for the PAV-adjusted calibration curve of posterior predictive
16+
#' sample, or `"confidence"` for the credible intervals of the calibration
17+
#' curve of the observed binary events.
1218
#'
1319
#' @template return-ggplot-or-data
1420
#'
1521
#' @section Plot Descriptions:
1622
#' \describe{
1723
#' \item{`ppc_calibration()`,`ppc_calibration_grouped()`}{
18-
#'
19-
#' },
24+
#' PAV-adjusted calibration plots showing the relationship between the
25+
#' predicted event probabilities and the conditional event probabilities.
26+
#' The `interval_type` parameter controls whether confidence intervals, or
27+
#' consistency intervals are computed.},
2028
#' \item{`ppc_calibration_overlay()`,`ppc_calibration_overlay_grouped()`}{
21-
#'
29+
#' Overlay plots showing posterior samples of PAV-adjusted calibration
30+
#' curves.
2231
#' },
2332
#' \item{`ppc_loo_calibration()`,`ppc_loo_calibration_grouped()`}{
24-
#'
33+
#' PAV-adjusted calibration plots to assess the calibration of the
34+
#' leave-one-out (LOO) predictive probabilities.
2535
#' }
2636
#' }
2737
#'
@@ -35,6 +45,10 @@
3545
#' prep <- (example_yrep_draws() - ymin) / (ymax - ymin)
3646
#'
3747
#' ppc_calibration_overlay(y, prep[1:50, ])
48+
#'
49+
#' # Compare confidence vs consistency intervals
50+
#' ppc_calibration(y, prep, interval_type = "confidence")
51+
#' ppc_calibration(y, prep, interval_type = "consistency")
3852
NULL
3953

4054

@@ -64,7 +78,7 @@ ppc_calibration_overlay <- function(
6478
ppc_calibration_overlay_grouped <- function(
6579
y, prep, group, ..., linewidth = 0.25, alpha = 0.7) {
6680
check_ignored_arguments(...)
67-
data <- .ppc_calibration_data(y, prep, group)
81+
data <- .ppc_calibration_data(y, prep = prep, group = group)
6882
ggplot(data) +
6983
geom_abline(color = "black", linetype = 2) +
7084
geom_line(aes(value, cep, group = rep_id, color = "yrep"),
@@ -83,15 +97,17 @@ ppc_calibration_overlay_grouped <- function(
8397
#' @rdname PPC-calibration
8498
#' @export
8599
ppc_calibration <- function(
86-
y, prep, prob = .95, show_mean = TRUE, ..., linewidth = 0.5, alpha = 0.7) {
100+
y, prep = NULL, yrep = NULL, prob = .95, interval_type = c("confidence", "consistency"), ...,
101+
linewidth = 0.5, alpha = 0.7) {
87102
check_ignored_arguments(...)
88-
data <- .ppc_calibration_data(y, prep) %>%
89-
group_by(y_id) %>%
103+
interval_type <- match.arg(interval_type)
104+
data <- .ppc_calibration_data(y, prep, yrep, NULL, interval_type) %>%
105+
group_by(idx) %>%
90106
summarise(
91-
value = median(value),
92-
lb = quantile(cep, .5 - .5 * prob),
93-
ub = quantile(cep, .5 + .5 * prob),
94-
cep = median(cep)
107+
value = mean(value),
108+
lb = quantile(cep_intervals, .5 - .5 * prob),
109+
ub = quantile(cep_intervals, .5 + .5 * prob),
110+
cep = mean(cep)
95111
)
96112

97113
ggplot(data) +
@@ -115,16 +131,26 @@ ppc_calibration <- function(
115131
#' @rdname PPC-calibration
116132
#' @export
117133
ppc_calibration_grouped <- function(
118-
y, prep, group, prob = .95, show_mean = TRUE, ..., linewidth = 0.5, alpha = 0.7) {
134+
y,
135+
prep = NULL,
136+
yrep = NULL,
137+
group,
138+
prob = .95,
139+
interval_type = c("confidence", "consistency"),
140+
...,
141+
linewidth = 0.5,
142+
alpha = 0.7) {
119143
check_ignored_arguments(...)
120-
data <- .ppc_calibration_data(y, prep, group) %>%
121-
group_by(group, y_id) %>%
144+
interval_type <- match.arg(interval_type)
145+
data <- .ppc_calibration_data(y, prep, yrep, group, interval_type) %>%
146+
group_by(group, idx) %>%
122147
summarise(
123-
value = median(value),
124-
lb = quantile(cep, .5 - .5 * prob),
125-
ub = quantile(cep, .5 + .5 * prob),
126-
cep = median(cep)
127-
)
148+
value = mean(value),
149+
lb = quantile(cep_intervals, .5 - .5 * prob),
150+
ub = quantile(cep_intervals, .5 + .5 * prob),
151+
cep = mean(cep),
152+
) %>%
153+
ungroup()
128154

129155
ggplot(data) +
130156
aes(value, cep) +
@@ -148,77 +174,218 @@ ppc_calibration_grouped <- function(
148174
#' @rdname PPC-calibration
149175
#' @export
150176
ppc_loo_calibration <- function(
151-
y, prep, lw, ..., linewidth = 0.25, alpha = 0.7) {
177+
y, yrep, lw, prob = .95, interval_type = c("confidence", "consistency"), ..., linewidth = 0.5, alpha = 0.7) {
152178
check_ignored_arguments(...)
153-
data <- .ppc_calibration_data(y, prep)
154-
ggplot(data) +
155-
geom_abline(color = "black", linetype = 2) +
156-
geom_line(
157-
aes(value, cep, group = rep_id, color = "yrep"),
158-
linewidth = linewidth, alpha = alpha
159-
) +
160-
scale_color_ppc() +
161-
bayesplot_theme_get() +
162-
legend_none() +
163-
coord_equal(xlim = c(0, 1), ylim = c(0, 1), expand = FALSE) +
164-
xlab("Predicted probability") +
165-
ylab("Conditional event probability") +
166-
NULL
179+
# Create LOO-predictive samples using resampling
180+
yrep_resampled <- .loo_resample_data(yrep, lw, psis_object = NULL)
181+
ppc_calibration(y, yrep = yrep_resampled, prob = prob, interval_type = interval_type, ..., linewidth = linewidth, alpha = alpha)
167182
}
168183

169184
#' @rdname PPC-calibration
170185
#' @export
171186
ppc_loo_calibration_grouped <- function(
172-
y, prep, group, lw, ..., linewidth = 0.25, alpha = 0.7) {
187+
y,
188+
yrep,
189+
group,
190+
lw,
191+
prob = .95,
192+
interval_type = c("confidence", "consistency"),
193+
...,
194+
linewidth = 0.5,
195+
alpha = 0.7) {
173196
check_ignored_arguments(...)
174-
data <- .ppc_calibration_data(y, prep, group)
175-
ggplot(data) +
176-
geom_abline(color = "black", linetype = 2) +
177-
geom_line(aes(value, cep, group = rep_id, color = "yrep"),
178-
linewidth = linewidth, alpha = alpha
179-
) +
180-
facet_wrap(vars(group)) +
181-
scale_color_ppc() +
182-
bayesplot_theme_get() +
183-
legend_none() +
184-
coord_equal(xlim = c(0, 1), ylim = c(0, 1), expand = FALSE) +
185-
xlab("Predicted probability") +
186-
ylab("Conditional event probability") +
187-
NULL
197+
# Create LOO-predictive samples using resampling
198+
yrep_resampled <- .loo_resample_data(yrep, lw, psis_object = NULL)
199+
ppc_calibration_grouped(
200+
y,
201+
yrep = yrep_resampled,
202+
group = group,
203+
prob = prob,
204+
interval_type = interval_type,
205+
...,
206+
linewidth = linewidth,
207+
alpha = alpha
208+
)
188209
}
189210

190-
.ppc_calibration_data <- function(y, prep, group = NULL) {
211+
.ppc_calibration_data <- function(
212+
y,
213+
prep = NULL,
214+
yrep = NULL,
215+
group = NULL,
216+
interval_type = c("confidence", "consistency")) {
191217
y <- validate_y(y)
192218
n_obs <- length(y)
193-
prep <- validate_predictions(prep, n_obs)
194-
if (any(prep > 1 | prep < 0)) {
195-
stop("Values of ´prep´ should be predictive probabilities between 0 and 1.")
219+
interval_type <- match.arg(interval_type)
220+
221+
# Determine if we're using prep (probabilities) or yrep (predictive samples)
222+
if (!is.null(prep)) {
223+
predictions <- validate_predictions(prep, n_obs)
224+
if (any(prep > 1 | prep < 0)) {
225+
stop(
226+
"Values of ´prep´ should be predictive probabilities between 0 and 1."
227+
)
228+
}
229+
is_probability <- TRUE
230+
} else if (!is.null(yrep)) {
231+
predictions <- validate_predictions(yrep, n_obs)
232+
is_probability <- FALSE
233+
} else {
234+
stop("Either 'prep' or 'yrep' must be provided.")
196235
}
236+
197237
if (!is.null(group)) {
198238
group <- validate_group(group, n_obs)
199239
} else {
200-
group <- rep(1, n_obs * nrow(prep))
240+
group <- rep(1, n_obs)
241+
}
242+
243+
data <- .ppd_data(predictions, group = group)
244+
245+
if (interval_type == "confidence") {
246+
if (is_probability) {
247+
# confidence interval from predicted probabilities:
248+
# cep = cep_intervals = monotone(y[order(ppred)])
249+
# i.e. posterior of the calibration curve trajectory
250+
# data %>%
251+
# group_by(group, rep_id) %>%
252+
# mutate(
253+
# ord = order(value),
254+
# value = value[ord],
255+
# cep_intervals = .monotone(y[ord]),
256+
# cep = cep_intervals,
257+
# idx = seq_len(n())
258+
# ) %>%
259+
data %>%
260+
group_by(group) %>%
261+
group_modify(.f = pava_transform(.x, .y, y, NULL)) %>%
262+
ungroup() %>%
263+
select(value, cep_intervals, cep, idx, group)
264+
} else {
265+
ppred <- colMeans(predictions)
266+
data %>%
267+
group_by(group, rep_id) %>%
268+
mutate(
269+
idx_boot = sample(y_id, n(), replace = TRUE),
270+
ord = order(ppred[idx_boot]),
271+
value = ppred[idx_boot][ord],
272+
cep_intervals = .monotone(y[idx_boot][ord]),
273+
cep = cep_intervals,
274+
idx = seq_len(n()),
275+
) %>%
276+
ungroup() %>%
277+
select(value, cep_intervals, cep, idx, group)
278+
}
279+
} else {
280+
# Consistency intervals
281+
if (is_probability) {
282+
# For prep (probabilities), generate predictive samples from binomial
283+
data %>%
284+
group_by(group, rep_id) %>%
285+
mutate(
286+
# Generate predictive samples from binomial distribution
287+
ord = order(value),
288+
value = value[ord],
289+
cep_intervals = .monotone(rbinom(n(), 1, value)),
290+
cep = .monotone(y[ord]),
291+
idx = seq_len(n())
292+
) %>%
293+
ungroup() %>%
294+
select(value, cep, cep_intervals, idx, group)
295+
} else {
296+
# For yrep (predictive samples), use column averages for ordering
297+
ppred <- colMeans(predictions)
298+
data %>%
299+
group_by(group, rep_id) %>%
300+
mutate(
301+
# Use column averages for ordering when yrep is provided
302+
ord = order(ppred[y_id]),
303+
cep_intervals = .monotone(value[ord]),
304+
value = ppred[y_id][ord],
305+
cep = .monotone(y[y_id][ord]),
306+
idx = seq_len(n()),
307+
) %>%
308+
ungroup() %>%
309+
select(value, cep_intervals, cep, idx, group)
310+
}
311+
}
312+
}
313+
314+
.loo_resample_data <- function(yrep, lw, psis_object) {
315+
lw <- .get_lw(lw, psis_object)
316+
stopifnot(identical(dim(yrep), dim(lw)))
317+
318+
# Resample each column (observation) with its corresponding weights
319+
n_obs <- ncol(yrep)
320+
n_draws <- nrow(yrep)
321+
322+
# Initialize resampled matrix
323+
yrep_resampled <- matrix(NA, nrow = n_draws, ncol = n_obs)
324+
325+
for (i in 1:n_obs) {
326+
# Create draws object for this observation
327+
draws_i <- posterior::as_draws_matrix(yrep[, i, drop = FALSE])
328+
329+
# Resample using the weights for this observation
330+
weights_i <- lw[, i]
331+
resampled_i <- posterior::resample_draws(
332+
draws_i,
333+
weights = weights_i, ndraws = n_draws
334+
)
335+
336+
# Extract the resampled values
337+
yrep_resampled[, i] <- as.numeric(resampled_i)
338+
}
339+
340+
# Add observation names if available
341+
if (!is.null(colnames(yrep))) {
342+
colnames(yrep_resampled) <- colnames(yrep)
201343
}
202344

345+
yrep_resampled
346+
}
347+
348+
.monotone <- function(y) {
203349
if (requireNamespace("monotone", quietly = TRUE)) {
204350
monotone <- monotone::monotone
205351
} else {
206352
monotone <- function(y) {
207353
stats::isoreg(y)$yf
208354
}
209355
}
210-
.ppd_data(prep, group = group) %>%
211-
group_by(group, rep_id) %>%
212-
mutate(
213-
ord = order(value),
214-
value = value[ord],
215-
cep = monotone(y[ord])
216-
) |>
217-
ungroup()
356+
monotone(y)
218357
}
219358

220-
.loo_resample_data <- function(prep, lw, psis_object) {
221-
lw <- .get_lw(lw, psis_object)
222-
stopifnot(identical(dim(prep), dim(lw)))
223-
# Work in progress here...
359+
pava_transform <- function(.x, .y, y, yrep, interval_type) {
360+
if (no_prob) {
361+
probs <- .x %>%
362+
group_by(y_id) %>%
363+
summarise(p = mean(yrep)) %>%
364+
arrange(y_id) %>%
365+
pull(p)
366+
ord <- order(probs)
367+
data <- .x |>
368+
group_by(rep_id) |>
369+
mutate(ord = ord) |>
370+
ungroup()
371+
} else {
372+
data <- data |>
373+
group_by(rep_id) |>
374+
mutate(ord = order(value)) |>
375+
ungroup()
376+
}
377+
if (interval_type == "confidence") {
378+
data %>%
379+
group_by(rep_id) %>%
380+
mutate(
381+
cep_intervals = .monotone(y[y_id][ord_v]),
382+
value = ifelse(is.null(yrep), value[ord_v], probs[ord],
383+
cep = cep_interval,
384+
idx = seq_len(n()),
385+
) %>%
386+
ungroup() %>%
387+
select(value, cep_intervals, cep, idx, rep_id, y_id)
388+
} else {
389+
data
390+
}
224391
}

0 commit comments

Comments
 (0)