Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions SDeMo/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## `v1.8.1`

- **added** a `ceterisparibus` explanation function for CP-plots
- **added** explanatory plots for trained models

## `v1.8.0`

- **added** an overload of `copy` for SDMs
Expand Down
2 changes: 1 addition & 1 deletion SDeMo/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SDeMo"
uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467"
version = "1.8.0"
version = "1.8.1"
authors = ["Timothée Poisot <[email protected]>"]

[deps]
Expand Down
170 changes: 169 additions & 1 deletion SDeMo/ext/MakieExtension.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,179 @@
module MakieExtension

import Makie
using Makie
using Statistics
using SDeMo
import SDeMo:
cpplot, cpplot!, iceplot, iceplot!, partialdependenceplot, partialdependenceplot!

function Makie.convert_arguments(P::Makie.PointBased, model::T) where {T <: AbstractSDM}
@assert isgeoreferenced(model)
return Makie.convert_arguments(P, model.coordinates)
end

function _shared_argument_cp_plots()
return Makie.@DocumentedAttributes begin
bins = 30
color = @inherit color :black
linestyle = @inherit linestyle :solid
linewidth = @inherit linewidth 1
alpha = @inherit alpha 1.0
stairs = true
center = :none # Alt values :midpoint, :value
colormap = @inherit colormap :viridis
colorrange = @inherit colorrange (0, 1)
end
end

# CP plot

function _check_trained(model)
if !istrained(model)
throw(UntrainedModelError())
end
return nothing
end

function _cpplot_data(model, inst, feat, bins)
_check_trained(model)
X = instance(model, inst; strict = false)
x = collect(LinRange(extrema(features(model, feat))..., bins))
Y = permutedims(repeat(X', length(x)))
Y[feat, :] .= x
y = predict(model, Y; threshold = false)
return (x, y)
end

function _cpplot_data(model, inst, f1, f2, bins)
_check_trained(model)
x1 = collect(LinRange(extrema(features(model, f1))..., bins))
x2 = collect(LinRange(extrema(features(model, f2))..., bins))
Y = zeros(bins, bins)
X = instance(model, inst; strict = false)
for i in eachindex(x1)
X[f1] = x1[i]
for j in eachindex(x2)
X[f2] = x2[j]
Y[i,j] = predict(model, X; threshold=false)
end
end
return (x1, x2, Y)
end

Makie.@recipe CPPlot (sdm, instance, feature) begin
_shared_argument_cp_plots()...
end

const OneDimCP = CPPlot{<:Tuple{AbstractSDM, Integer, Integer}}
const TwoDimCP = CPPlot{<:Tuple{AbstractSDM, <:Integer, <:Integer, <:Integer}}

Makie.convert_arguments(::Type{OneDimCP}, sdm::AbstractSDM, inst::Integer, feat::Integer) = (sdm, inst, feat)
Makie.convert_arguments(::Type{TwoDimCP}, sdm::AbstractSDM, inst::Integer, f1::Integer, f2::Integer) = (sdm, inst, [f1, f2])

function Makie.plot!(cp::OneDimCP)
x, y = _cpplot_data(cp.sdm[], cp.instance[], cp.feature[], cp.bins[])
if cp.center[] == :midpoint
x .-= (x[end] + x[begin]) / 2
end
if cp.center[] == :value
x .-= instance(cp.sdm[], cp.instance[]; strict = false)[cp.feature[]]
end
plfunc! = cp.stairs[] ? stairs! : lines!
plfunc!(cp, cp.attributes, x, y)
return cp
end

function Makie.plot!(cp::TwoDimCP)
model = cp.arg1[]
inst = cp.arg2[]
f1 = cp.arg3[]
f2 = cp.arg4[]
x1, x2, y = _cpplot_data(model, inst, f1, f2, cp.bins[])
if cp.center[] == :midpoint
x1 .-= (x1[end] + x1[begin]) / 2
x2 .-= (x2[end] + x2[begin]) / 2
end
if cp.center[] == :value
x1 .-= instance(model, inst; strict = false)[f1]
x2 .-= instance(model, inst; strict = false)[f2]
end
heatmap!(cp, cp.attributes, x1, x2, y)
return cp
end

# ICE plot

Makie.@recipe ICEPlot (sdm, instances, feature) begin
_shared_argument_cp_plots()...
end

Makie.plottype(::ICEPlot) = Lines
Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, x::Colon, y::Int) =
(sdm, eachindex(labels(sdm)), y)
Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, x::AbstractRange, y::Int) =
(sdm, collect(x), y)
Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, x::Vector{Int}, y::Int) =
(sdm, x, y)
Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, y::Int) =
(sdm, eachindex(labels(sdm)), y)

function Makie.plot!(ice::ICEPlot)
for i in ice.instances[]
cpplot!(ice, ice.attributes, ice.sdm[], i, ice.feature[])
end
return ice
end

# PCP plot

Makie.@recipe PartialDependencePlot (sdm, instances, feature) begin
_shared_argument_cp_plots()...
ribbon = nothing # Or a function for the ribbon
background = :grey80
end

Makie.plottype(::PartialDependencePlot) = Lines
Makie.convert_arguments(::Type{PartialDependencePlot}, sdm::AbstractSDM, x::Colon, y::Int) =
(sdm, eachindex(labels(sdm)), y)
Makie.convert_arguments(
::Type{PartialDependencePlot},
sdm::AbstractSDM,
x::AbstractRange,
y::Int,
) = (sdm, collect(x), y)
Makie.convert_arguments(
::Type{PartialDependencePlot},
sdm::AbstractSDM,
x::Vector{Int},
y::Int,
) = (sdm, x, y)
Makie.convert_arguments(
::Type{PartialDependencePlot},
sdm::AbstractSDM,
y::Int,
) = (sdm, eachindex(labels(sdm)), y)

function Makie.plot!(pdp::PartialDependencePlot)
x, _ = _cpplot_data(pdp.sdm[], 1, pdp.feature[], pdp.bins[])
Y = zeros(Float64, pdp.bins[], length(pdp.instances[]))
for (i, inst) in enumerate(pdp.instances[])
xᵢ, yᵢ = _cpplot_data(pdp.sdm[], inst, pdp.feature[], pdp.bins[])
Y[:, i] = yᵢ
end
if pdp.center[] == :midpoint
x .-= (x[end] + x[begin]) / 2
end
if pdp.center[] == :value
@warn "Using center = :value has no effect on a partial dependence plot"
end
μ = dropdims(mapslices(Statistics.mean, Y; dims = 2); dims = 2)
if !isnothing(pdp.ribbon[])
r = dropdims(mapslices(pdp.ribbon[], Y; dims = 2); dims = 2)
band!(pdp, x, μ .- r, μ .+ r; color = pdp.background[])
end
plfunc! = pdp.stairs[] ? stairs! : lines!
plfunc!(pdp, pdp.attributes, x, μ)
return pdp
end

end
12 changes: 12 additions & 0 deletions SDeMo/src/SDeMo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ export counterfactual
include("explanations/partialresponse.jl")
export partialresponse

include("explanations/ceterisparibus.jl")
export ceterisparibus

include("explanations/shapley.jl")
export explain

Expand All @@ -148,4 +151,13 @@ export writesdm, loadsdm
include("utilities/varia.jl")
export iqr

# Figures
function cpplot end
function cpplot! end
function iceplot end
function iceplot! end
function partialdependenceplot end
function partialdependenceplot! end
export cpplot, cpplot!, iceplot, iceplot!, partialdependenceplot, partialdependenceplot!

end # module SDeMo
37 changes: 37 additions & 0 deletions SDeMo/src/explanations/ceterisparibus.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
ceterisparibus(model::T, instance::Integer, feature::Integer; bins=50, kwargs...)

TODO

All keyword arguments are passed to `predict`.
"""
function ceterisparibus(model::T, i::Integer, feature::Integer, args...; bins::Integer=50, kwargs...) where {T <: AbstractSDM}
X = instance(model, i; strict = false)
x = collect(LinRange(extrema(features(model, feature))..., bins))
Y = permutedims(repeat(X', length(x)))
Y[feature, :] .= x
y = predict(model, Y; kwargs...)
return (x, y)
end

"""
ceterisparibus(model::T, instance::Integer, feature1::Integer, feature2::Integer; bins=50, kwargs...)

TODO

All keyword arguments are passed to `predict`.
"""
function ceterisparibus(model::T, i::Integer, feature1::Integer, feature2::Integer, args...; bins::Integer=50, kwargs...) where {T <: AbstractSDM}
x1 = collect(LinRange(extrema(features(model, feature1))..., bins))
x2 = collect(LinRange(extrema(features(model, feature2))..., bins))
Y = zeros(bins, bins)
X = instance(model, i; strict = false)
for i in eachindex(x1)
X[f1] = x1[i]
for j in eachindex(x2)
X[f2] = x2[j]
Y[i,j] = predict(model, X; kwargs...)
end
end
return (x1, x2, Y)
end
14 changes: 7 additions & 7 deletions SDeMo/src/explanations/partialresponse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function _fill_partialresponse_data!(nx, model::T, variable, inflated) where {T
end

"""
partialresponse(model::T, i::Integer, args...; inflated::Bool, kwargs...)
partialresponse(model::T, i::Integer; bins=50, inflated::Bool, kwargs...)

This method returns the partial response of applying the trained model to a
simulated dataset where all variables *except* `i` are set to their mean value.
Expand All @@ -45,14 +45,14 @@ The different arguments that can follow the variable position are

All keyword arguments are passed to `predict`.
"""
function partialresponse(model::T, i::Integer, args...; inflated::Bool=false, kwargs...) where {T <: AbstractSDM}
nx = SDeMo._make_partialresponse_data(model, i, args...)
function partialresponse(model::T, i::Integer, args...; bins::Integer=50, inflated::Bool=false, kwargs...) where {T <: AbstractSDM}
nx = SDeMo._make_partialresponse_data(model, i, bins)
SDeMo._fill_partialresponse_data!(nx, model, i, inflated)
return (nx[i,:], predict(model, nx; kwargs...))
end

"""
partialresponse(model::T, i::Integer, j::Integer, s::Tuple=(50, 50); inflated::Bool, kwargs...)
partialresponse(model::T, i::Integer, j::Integer; bins=50, inflated::Bool, kwargs...)

This method returns the partial response of applying the trained model to a
simulated dataset where all variables *except* `i` and `j` are set to their mean
Expand All @@ -64,9 +64,9 @@ and `j`, the size of which is given by the last argument `s` (defaults to 50 ×

All keyword arguments are passed to `predict`.
"""
function partialresponse(model::T, i::Integer, j::Integer, s::Tuple=(50, 50); inflated::Bool=false, kwargs...) where {T <: AbstractSDM}
irange = LinRange(extrema(features(model, i))..., s[1])
jrange = LinRange(extrema(features(model, j))..., s[2])
function partialresponse(model::T, i::Integer, j::Integer, s::Tuple=(50, 50); bins::Integer=50, inflated::Bool=false, kwargs...) where {T <: AbstractSDM}
irange = LinRange(extrema(features(model, i))..., bins)
jrange = LinRange(extrema(features(model, j))..., bins)

nx = zeros(eltype(features(model)), size(features(model), 1), length(irange)*length(jrange))

Expand Down
Loading
Loading