Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function for taking expectations wrt smoothed weights #61

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
DimensionalData = "0.24"
Expand All @@ -21,6 +22,7 @@ Printf = "1.6"
RecipesBase = "1"
ReferenceTests = "0.9, 0.10"
Statistics = "1.6"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
PSISResult
psis
ess_is
PSIS.expectation
```

## Plotting
Expand Down
2 changes: 2 additions & 0 deletions src/PSIS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module PSIS
using LogExpFunctions: LogExpFunctions
using Printf: @sprintf
using Statistics: Statistics
using StatsBase: StatsBase

export PSISPlots
export PSISResult
Expand All @@ -12,6 +13,7 @@ include("utils.jl")
include("generalized_pareto.jl")
include("core.jl")
include("ess.jl")
include("expectation.jl")
include("recipes/plots.jl")

end
73 changes: 73 additions & 0 deletions src/expectation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
expectation(x, psis_result::PSISResult; kind=Statistics.mean)

Compute the expectation of `x` with respect to the weights in `psis_result`.

# Arguments

- `x`: An array of values of shape `(draws[, chains[, params...]])`, to compute the
expectation of with respect to smoothed importance weights.
- `psis_result`: A `PSISResult` object containing the smoothed importance weights with shape
`(draws[, chains, params...])`.

# Keywords

- `kind=Statistics.mean`: The type of expectation to be computed. It can be any function
that has a method for computing the weighted expectation
`f(x::AbstractVector, weights::AbstractVector) -> Real`. In particular, the following
are supported:

+ `Statistics.mean`
+ `Statistics.median`
+ `Statistics.std`
+ `Statistics.var`
+ `Base.Fix2(Statistics.quantile, p::Real)` for `quantile(x, weights, p)`

# Returns

- `values`: An array of shape `(other..., params...)` or real number of `other` and `params`
are empty containing the expectation of `x` with respect to the smoothed importance
weights.
"""
function expectation(x::AbstractArray, psis_result::PSISResult; kind=Statistics.mean)
log_weights = psis_result.log_weights
weights = psis_result.weights

Check warning on line 34 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L32-L34

Added lines #L32 - L34 were not covered by tests

param_dims = _param_dims(log_weights)
exp_dims = _param_dims(x)
if !isempty(exp_dims) && length(exp_dims) != length(param_dims)
throw(

Check warning on line 39 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L36-L39

Added lines #L36 - L39 were not covered by tests
ArgumentError(
"The trailing dimensions of `x` must match the parameter dimensions of `psis_result.weights`",
),
)
end
param_axes = map(Base.Fix1(axes, log_weights), param_dims)
exp_axes = map(Base.Fix1(axes, x), exp_dims)
if !isempty(exp_axes) && exp_axes != param_axes
throw(

Check warning on line 48 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L45-L48

Added lines #L45 - L48 were not covered by tests
ArgumentError(
"The trailing axes of `x` must match the parameter axes of `psis_result.weights`",
),
)
end

T = Base.promote_eltype(x, log_weights)
values = similar(x, T, param_axes)

Check warning on line 56 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L55-L56

Added lines #L55 - L56 were not covered by tests

for i in _eachparamindex(weights)
w_i = StatsBase.AnalyticWeights(vec(_selectparam(weights, i)), 1)
x_i = vec(ndims(x) < 3 ? x : _selectparam(x, i))
values[i] = _expectation(kind, x_i, w_i)
end

Check warning on line 62 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L58-L62

Added lines #L58 - L62 were not covered by tests

iszero(ndims(values)) && return values[]

Check warning on line 64 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L64

Added line #L64 was not covered by tests

return values

Check warning on line 66 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L66

Added line #L66 was not covered by tests
end

_expectation(f, x, weights) = f(x, weights)
function _expectation(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x, weights)
prob = f.x
return Statistics.quantile(x, weights, prob)

Check warning on line 72 in src/expectation.jl

View check run for this annotation

Codecov / codecov/patch

src/expectation.jl#L69-L72

Added lines #L69 - L72 were not covered by tests
end
Loading