Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Collate:
'distribution_R6_class.R'
'distribution_continuous.R'
'distribution_discrete.R'
'distribution_mixture.R'
'mastiff-package.R'
'mixture_of_two_normals.R'
'plot_posterior.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export(check_numeric)
export(distribution.binomial)
export(distribution.exponential)
export(distribution.gamma)
export(distribution.mixture)
export(distribution.negative_binomial)
export(distribution.normal)
export(distribution.point_mass)
Expand Down
154 changes: 154 additions & 0 deletions R/distribution_mixture.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Include R6_util_class.R and distribution_R6_class.R to guarantee base classes exist
# when loading the package prior to defining classes

################################################################################/
# distribution.mixture.class
################################################################################/
#' Class: `distribution.mixture.class`
#' @description Class describing a weighted mixture of continuous distributions
#'
#' @param q vector of quantiles.
#' @param p vector of probabilities.
#' @param log.p logical; if TRUE, probabilities p are given as `log(p)`.
#' @param lower.tail logical; if TRUE (default), probabilities are \eqn{P[ X \leq x ]},
#' otherwise, \eqn{P[X>x]}.
#'
#' @include R6_class.R
#' @include distribution_R6_class.R
distribution.mixture.class <- R6.class(
  classname = "distribution.mixture.class",
  inherit = distribution.continuous.class,
  private = list(
    # list of component distribution objects (stored by reference)
    .distributions = NULL,
    # number of components, cached at construction
    .n_distributions = NULL,
    # mixture weights, stored normalised so they sum to 1
    .weights = NULL
  ),
  public = list(
    ############################################################################/
    # initialize
    ############################################################################/
    #' @description Create a new object of class `distribution.mixture.class`
    #' @param distributions list of distribution objects, all of which must
    #' inherit from `distribution.continuous.class`.
    #' @param weights numeric vector of non-negative mixture weights, one per
    #' distribution, summing to 1 (within 1e-10).
    initialize = function( distributions, weights ){
      stopifnot( length( distributions ) == length( weights ) )
      stopifnot( all( unlist( lapply( distributions, function( d ) is.distribution( d ) ) ) ) )
      stopifnot( all( unlist( lapply( distributions, function( d ) inherits( d, "distribution.continuous.class" ) ) ) ) )
      # a mixture needs non-negative weights; sample() in r() also rejects
      # negative probabilities, so fail early with a clear check
      stopifnot( all( weights >= 0 ) )
      stopifnot( abs( sum( weights ) - 1 ) < 1e-10 )

      private$.distributions   <- distributions
      # renormalise to remove the rounding error the tolerance above allows
      private$.weights         <- weights / sum( weights )
      private$.n_distributions <- length( distributions )
      # the support of the mixture spans the supports of all components
      private$.support <- c(
        min( unlist( lapply( distributions, function( d ) d$support[1] ) ) ),
        max( unlist( lapply( distributions, function( d ) d$support[2] ) ) )
      )
    },
    ############################################################################/
    # density
    ############################################################################/
    #' @description Density function for a random variable of the mixture
    #' @param x vector of points at which the density is evaluated.
    #' @param log logical; if TRUE, the log of the density is returned.
    d = function( x, log = FALSE ){
      # one column per component, one row per element of x
      ds <- matrix( unlist( lapply( private$.distributions,
                                    function( d ) d$d( x ) ) ),
                    ncol = self$n_distributions )
      dens <- ( ds %*% self$weights )[, 1 ]
      # the mixture is linear on the natural scale, so take the log only
      # after the weighted sum; base:: avoids confusion with the argument
      if( log )
        return( base::log( dens ) )
      return( dens )
    },
    ##############################################################################/
    # cumulative distribution function
    ##############################################################################/
    #' @description Evaluates the distribution function of the mixture
    p = function( q, lower.tail = TRUE, log.p = FALSE ){
      # always ask the components for probabilities on the natural scale:
      # a weighted sum of log-probabilities is not the log of the mixture CDF
      ps <- matrix( unlist( lapply( private$.distributions,
                                    function( d ) d$p( q, lower.tail = lower.tail, log.p = FALSE ) ) ),
                    ncol = self$n_distributions )
      prob <- ( ps %*% self$weights )[, 1 ]
      if( log.p )
        return( log( prob ) )
      return( prob )
    },
    ##############################################################################/
    # quantile function
    ##############################################################################/
    #' @description Evaluates the quantile function of the mixture
    q = function( p, lower.tail = TRUE, log.p = FALSE ){
      # A weighted average of the component quantiles is NOT the quantile of
      # the mixture, so the calculation is delegated to the parent class.
      # NOTE(review): presumably the parent inverts self$p - confirm in
      # distribution.continuous.class
      super$q( p, lower.tail = lower.tail, log.p = log.p )
    },
    ############################################################################/
    # random deviates
    ############################################################################/
    #' @description Generates random samples of the mixture
    #' @param n number of samples to draw.
    r = function( n ){
      dists  <- self$distributions
      n_dist <- length( dists )
      # pick the component each draw comes from, according to the weights
      model  <- sample.int( n_dist, n, replace = TRUE, prob = self$weights )

      # one slot per requested draw (not per component), so the result has
      # length n even when n is smaller than the number of components
      ret <- numeric( n )
      for( ddx in seq_len( n_dist ) ) {
        idxs <- which( model == ddx )
        if( length( idxs ) > 0 )
          ret[ idxs ] <- dists[[ ddx ]]$r( length( idxs ) )
      }
      return( ret )
    }
  ),
  active = list(
    support = function( val ){
      private$.staticReturn( val, "support" )
    },
    ############################################################################/
    # n_distributions
    ############################################################################/
    #' @description Number of distributions in the mixture
    n_distributions = function( val ){
      private$.staticReturn( val, "n_distributions" )
    },
    ############################################################################/
    # distributions
    ############################################################################/
    #' @description The distributions of the mixture
    distributions = function( val ){
      if( missing( val ) )
        return( private$.staticReturn( val, "distributions" ) )

      # allow updates of values on distributions objects: an expression such
      # as `mix$distributions[[i]]$params <- x` assigns the (same) list back
      # to this binding, so accept it only if every element is the very same
      # object (same address) as the one already stored
      old_val <- private$.distributions
      stopifnot( length( val ) == length( old_val ) )
      for( idx in seq_along( val ) ) {
        stopifnot( is.distribution( val[[ idx ]] ) )
        stopifnot( data.table::address( val[[ idx ]] ) ==
                     data.table::address( old_val[[ idx ]] ) )
      }
    },
    ############################################################################/
    # weights
    ############################################################################/
    #' @description The weights of each distribution in the mixture
    weights = function( new = NA ){
      # a single NA (the default) is the sentinel for "read the weights"
      if( length( new ) == 1 && is.na( new ) )
        return( private$.weights )

      stopifnot( length( new ) == self$n_distributions )
      stopifnot( all( new >= 0 ) )
      stopifnot( abs( 1 - sum( new ) ) < 1e-10 )
      private$.weights <- new / sum( new )
    }
  )
)

#' Create a mixture distribution
#'
#' Convenience constructor returning a new object of class
#' `distribution.mixture.class`.
#'
#' @param distributions a list of distribution objects; every element must be
#' a continuous distribution
#' @param weights numeric vector holding one mixture weight per distribution;
#' the weights must sum to 1
#'
#' @returns An object of class [distribution.mixture.class]
#'
#' @seealso [Mastiff-Distributions]
#' @export
distribution.mixture <- function( distributions, weights ){
  mixture <- distribution.mixture.class$new( distributions, weights )
  return( mixture )
}



22 changes: 22 additions & 0 deletions man/distribution.mixture.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

114 changes: 114 additions & 0 deletions man/distribution.mixture.class.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

90 changes: 90 additions & 0 deletions tests/testthat/test-distribution_mixutre.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
library( mastiff )
library( testthat )

test_that( "Test mixture distribution", {

  # two mixtures are exercised: a pair of gammas and a pair of normals
  all_dists <- list(
    list(
      distribution.gamma( 2,1 ),
      distribution.gamma( 2,2 )
    ),
    list(
      distribution.normal( 0,1 ),
      distribution.normal( 1,2 )
    )
  )
  # replacement parameters used to check the components can be updated in place
  all_params <- list(
    list(
      list( shape = 1, rate = 2 ),
      list( shape = 3, rate = 1 )
    ),
    list(
      list( mean = 1, sd = 2 ),
      list( mean = -1, sd = 0.5 )
    )
  )

  for( ddx in seq_along( all_dists ) ) {
    dists      <- all_dists[[ ddx ]]
    new_params <- all_params[[ ddx ]]

    weights <- rep( 1, length( dists ) ) / length( dists )
    expect_no_error( mix <- distribution.mixture( dists, weights ) )
    expect_equal( mix$n_distributions, length( weights ) )
    expect_equal( mix$weights, weights )
    expect_equal( mix$support, dists[[1]]$support )

    for( idx in seq_along( weights ) ) {
      # check distributions as initialised and updatable
      expect_equal( data.table::address( mix$distributions[[idx]] ),
                    data.table::address( dists[[idx]] ) )
      expect_no_error( mix$distributions[[idx]]$params <- new_params[[ idx ]] )
      for( name in names( new_params[[ idx ]] ) )
        expect_equal( mix$distributions[[idx]]$params[[name]], new_params[[ idx ]][[name]] )

      # check weighting just one returns the correct distribution functions
      weights <- rep( 0, mix$n_distributions )
      weights[idx] <- 1
      expect_no_error( mix$weights <- weights )
      expect_equal( mix$weights, weights )
      expect_equal( mix$d( c(1,2) ), mix$distributions[[idx]]$d( c(1,2) ) )
      expect_equal( mix$p( c(0.25,0.75) ), mix$distributions[[idx]]$p( c(0.25,0.75) ) )
      expect_equal( mix$q( c(0.25,0.75) ), mix$distributions[[idx]]$q( c(0.25,0.75) ) )
    }

    # check the distribution of random draws from mixture agree with CDF
    weights <- seq_len( mix$n_distributions )
    mix$weights <- weights / sum( weights )
    n_samples <- 1e4
    xs <- mix$r( n_samples )

    # get the quantiles of the sample
    probs <- seq( 0.1, 0.9, 0.1)
    rqs   <- quantile( xs, probs )

    # compare to calculated CDF at these points (samples in a quantile are binomially distributed )
    ps <- mix$p( rqs )
    expect_lt( max( abs( ps - probs) / sqrt( probs * ( 1 - probs ) / n_samples ) ), 4 )

    # check within range of the quantile of these points
    qmin <- mix$q( probs - 4 * sqrt( probs * ( 1 - probs ) / n_samples ))
    qmax <- mix$q( probs + 4 * sqrt( probs * ( 1 - probs ) / n_samples ))
    expect_equal( sum( rqs < qmax ), length( probs ) )
    expect_equal( sum( rqs > qmin ), length( probs ) )
  }
} )


test_that( "Check error on invalid update of mixture distribution", {
  # a valid two-component mixture to start from
  gamma_pair    <- list( distribution.gamma( 2, 1 ), distribution.gamma( 2, 2 ) )
  equal_weights <- c( 0.5, 0.5 )
  expect_no_error( mix <- distribution.mixture( gamma_pair, equal_weights ) )

  # an object that is not a distribution
  not_a_distribution <- R6.class( "non-dist" )

  # number of weights does not match the number of distributions
  expect_error( distribution.mixture( gamma_pair, 1 ) )

  # one element of the list is not a distribution
  bad_list    <- c( gamma_pair, not_a_distribution )
  bad_weights <- c( equal_weights, 1 )
  expect_error( distribution.mixture( bad_list, bad_weights ) )

  # updating the weights with a vector of the wrong length
  expect_error( mix$weights <- 1 )
} )

Loading