diff --git a/SDeMo/CHANGELOG.md b/SDeMo/CHANGELOG.md index 5eb133d5d0..3cc8546f6f 100644 --- a/SDeMo/CHANGELOG.md +++ b/SDeMo/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## `v1.8.1` + +- **added** a `ceterisparibus` explanation function for CP-plots +- **added** explanatory plots for trained models + ## `v1.8.0` - **added** an overload of `copy` for SDMs diff --git a/SDeMo/Project.toml b/SDeMo/Project.toml index 5d02a8d021..8f1fe76cb5 100644 --- a/SDeMo/Project.toml +++ b/SDeMo/Project.toml @@ -1,6 +1,6 @@ name = "SDeMo" uuid = "3e5feb82-bcca-434d-9cd5-c11731a21467" -version = "1.8.0" +version = "1.8.1" authors = ["Timothée Poisot "] [deps] diff --git a/SDeMo/ext/MakieExtension.jl b/SDeMo/ext/MakieExtension.jl index 8dcb49a15d..5f14d5cad2 100644 --- a/SDeMo/ext/MakieExtension.jl +++ b/SDeMo/ext/MakieExtension.jl @@ -1,11 +1,179 @@ module MakieExtension -import Makie +using Makie +using Statistics using SDeMo +import SDeMo: + cpplot, cpplot!, iceplot, iceplot!, partialdependenceplot, partialdependenceplot! 
function Makie.convert_arguments(P::Makie.PointBased, model::T) where {T <: AbstractSDM} @assert isgeoreferenced(model) return Makie.convert_arguments(P, model.coordinates) end +function _shared_argument_cp_plots() + return Makie.@DocumentedAttributes begin + bins = 30 + color = @inherit color :black + linestyle = @inherit linestyle :solid + linewidth = @inherit linewidth 1 + alpha = @inherit alpha 1.0 + stairs = true + center = :none # Alt values :midpoint, :value + colormap = @inherit colormap :viridis + colorrange = @inherit colorrange (0, 1) + end +end + +# CP plot + +function _check_trained(model) + if !istrained(model) + throw(UntrainedModelError()) + end + return nothing +end + +function _cpplot_data(model, inst, feat, bins) + _check_trained(model) + X = instance(model, inst; strict = false) + x = collect(LinRange(extrema(features(model, feat))..., bins)) + Y = permutedims(repeat(X', length(x))) + Y[feat, :] .= x + y = predict(model, Y; threshold = false) + return (x, y) +end + +function _cpplot_data(model, inst, f1, f2, bins) + _check_trained(model) + x1 = collect(LinRange(extrema(features(model, f1))..., bins)) + x2 = collect(LinRange(extrema(features(model, f2))..., bins)) + Y = zeros(bins, bins) + X = instance(model, inst; strict = false) + for i in eachindex(x1) + X[f1] = x1[i] + for j in eachindex(x2) + X[f2] = x2[j] + Y[i,j] = predict(model, X; threshold=false) + end + end + return (x1, x2, Y) +end + +Makie.@recipe CPPlot (sdm, instance, feature) begin + _shared_argument_cp_plots()... 
+end
+
+const OneDimCP = CPPlot{<:Tuple{AbstractSDM, Integer, Integer}}
+const TwoDimCP = CPPlot{<:Tuple{AbstractSDM, <:Integer, <:Integer, <:Integer}}
+
+Makie.convert_arguments(::Type{OneDimCP}, sdm::AbstractSDM, inst::Integer, feat::Integer) = (sdm, inst, feat)
+Makie.convert_arguments(::Type{TwoDimCP}, sdm::AbstractSDM, inst::Integer, f1::Integer, f2::Integer) = (sdm, inst, f1, f2)
+
+function Makie.plot!(cp::OneDimCP)
+    x, y = _cpplot_data(cp.sdm[], cp.instance[], cp.feature[], cp.bins[])
+    if cp.center[] == :midpoint
+        x .-= (x[end] + x[begin]) / 2
+    end
+    if cp.center[] == :value
+        x .-= instance(cp.sdm[], cp.instance[]; strict = false)[cp.feature[]]
+    end
+    plfunc! = cp.stairs[] ? stairs! : lines!
+    plfunc!(cp, cp.attributes, x, y)
+    return cp
+end
+
+function Makie.plot!(cp::TwoDimCP)
+    model = cp.arg1[]
+    inst = cp.arg2[]
+    f1 = cp.arg3[]
+    f2 = cp.arg4[]
+    x1, x2, y = _cpplot_data(model, inst, f1, f2, cp.bins[])
+    if cp.center[] == :midpoint
+        x1 .-= (x1[end] + x1[begin]) / 2
+        x2 .-= (x2[end] + x2[begin]) / 2
+    end
+    if cp.center[] == :value
+        x1 .-= instance(model, inst; strict = false)[f1]
+        x2 .-= instance(model, inst; strict = false)[f2]
+    end
+    heatmap!(cp, cp.attributes, x1, x2, y)
+    return cp
+end
+
+# ICE plot
+
+Makie.@recipe ICEPlot (sdm, instances, feature) begin
+    _shared_argument_cp_plots()...
+end + +Makie.plottype(::ICEPlot) = Lines +Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, x::Colon, y::Int) = + (sdm, eachindex(labels(sdm)), y) +Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, x::AbstractRange, y::Int) = + (sdm, collect(x), y) +Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, x::Vector{Int}, y::Int) = + (sdm, x, y) +Makie.convert_arguments(::Type{ICEPlot}, sdm::AbstractSDM, y::Int) = + (sdm, eachindex(labels(sdm)), y) + +function Makie.plot!(ice::ICEPlot) + for i in ice.instances[] + cpplot!(ice, ice.attributes, ice.sdm[], i, ice.feature[]) + end + return ice +end + +# PCP plot + +Makie.@recipe PartialDependencePlot (sdm, instances, feature) begin + _shared_argument_cp_plots()... + ribbon = nothing # Or a function for the ribbon + background = :grey80 +end + +Makie.plottype(::PartialDependencePlot) = Lines +Makie.convert_arguments(::Type{PartialDependencePlot}, sdm::AbstractSDM, x::Colon, y::Int) = + (sdm, eachindex(labels(sdm)), y) +Makie.convert_arguments( + ::Type{PartialDependencePlot}, + sdm::AbstractSDM, + x::AbstractRange, + y::Int, +) = (sdm, collect(x), y) +Makie.convert_arguments( + ::Type{PartialDependencePlot}, + sdm::AbstractSDM, + x::Vector{Int}, + y::Int, +) = (sdm, x, y) +Makie.convert_arguments( + ::Type{PartialDependencePlot}, + sdm::AbstractSDM, + y::Int, +) = (sdm, eachindex(labels(sdm)), y) + +function Makie.plot!(pdp::PartialDependencePlot) + x, _ = _cpplot_data(pdp.sdm[], 1, pdp.feature[], pdp.bins[]) + Y = zeros(Float64, pdp.bins[], length(pdp.instances[])) + for (i, inst) in enumerate(pdp.instances[]) + xᵢ, yᵢ = _cpplot_data(pdp.sdm[], inst, pdp.feature[], pdp.bins[]) + Y[:, i] = yᵢ + end + if pdp.center[] == :midpoint + x .-= (x[end] + x[begin]) / 2 + end + if pdp.center[] == :value + @warn "Using center = :value has no effect on a partial dependence plot" + end + μ = dropdims(mapslices(Statistics.mean, Y; dims = 2); dims = 2) + if !isnothing(pdp.ribbon[]) + r = 
dropdims(mapslices(pdp.ribbon[], Y; dims = 2); dims = 2)
+        band!(pdp, x, μ .- r, μ .+ r; color = pdp.background[])
+    end
+    plfunc! = pdp.stairs[] ? stairs! : lines!
+    plfunc!(pdp, pdp.attributes, x, μ)
+    return pdp
+end
+
 end
\ No newline at end of file
diff --git a/SDeMo/src/SDeMo.jl b/SDeMo/src/SDeMo.jl
index 235184938f..b3d0554f52 100644
--- a/SDeMo/src/SDeMo.jl
+++ b/SDeMo/src/SDeMo.jl
@@ -137,6 +137,9 @@ export counterfactual
 include("explanations/partialresponse.jl")
 export partialresponse
 
+include("explanations/ceterisparibus.jl")
+export ceterisparibus
+
 include("explanations/shapley.jl")
 export explain
 
@@ -148,4 +151,13 @@ export writesdm, loadsdm
 include("utilities/varia.jl")
 export iqr
 
+# Figures
+function cpplot end
+function cpplot! end
+function iceplot end
+function iceplot! end
+function partialdependenceplot end
+function partialdependenceplot! end
+export cpplot, cpplot!, iceplot, iceplot!, partialdependenceplot, partialdependenceplot!
+
 end # module SDeMo
diff --git a/SDeMo/src/explanations/ceterisparibus.jl b/SDeMo/src/explanations/ceterisparibus.jl
new file mode 100644
index 0000000000..9a085c8447
--- /dev/null
+++ b/SDeMo/src/explanations/ceterisparibus.jl
@@ -0,0 +1,37 @@
+"""
+    ceterisparibus(model::T, instance::Integer, feature::Integer; bins=50, kwargs...)
+
+Returns the ceteris paribus profile of the trained model for a single instance: `feature` is swept over `bins` evenly spaced values spanning its observed range (all other features held at the instance's values), and the function returns a tuple with the grid of feature values and the corresponding predictions.
+
+All keyword arguments are passed to `predict`.
+"""
+function ceterisparibus(model::T, i::Integer, feature::Integer, args...; bins::Integer=50, kwargs...) where {T <: AbstractSDM}
+    X = instance(model, i; strict = false)
+    x = collect(LinRange(extrema(features(model, feature))..., bins))
+    Y = permutedims(repeat(X', length(x)))
+    Y[feature, :] .= x
+    y = predict(model, Y; kwargs...)
+    return (x, y)
+end
+
+"""
+    ceterisparibus(model::T, instance::Integer, feature1::Integer, feature2::Integer; bins=50, kwargs...)
+
+Two-features version of the ceteris paribus profile: returns the grids of values for `feature1` and `feature2`, and a `bins` by `bins` matrix of predictions where entry (i, j) is the prediction with `feature1` set to the i-th grid value and `feature2` set to the j-th grid value, all other features held at the instance's values.
+
+All keyword arguments are passed to `predict`.
+""" +function ceterisparibus(model::T, i::Integer, feature1::Integer, feature2::Integer, args...; bins::Integer=50, kwargs...) where {T <: AbstractSDM} + x1 = collect(LinRange(extrema(features(model, feature1))..., bins)) + x2 = collect(LinRange(extrema(features(model, feature2))..., bins)) + Y = zeros(bins, bins) + X = instance(model, i; strict = false) + for i in eachindex(x1) + X[f1] = x1[i] + for j in eachindex(x2) + X[f2] = x2[j] + Y[i,j] = predict(model, X; kwargs...) + end + end + return (x1, x2, Y) +end diff --git a/SDeMo/src/explanations/partialresponse.jl b/SDeMo/src/explanations/partialresponse.jl index 683038535b..780cea04c6 100644 --- a/SDeMo/src/explanations/partialresponse.jl +++ b/SDeMo/src/explanations/partialresponse.jl @@ -29,7 +29,7 @@ function _fill_partialresponse_data!(nx, model::T, variable, inflated) where {T end """ - partialresponse(model::T, i::Integer, args...; inflated::Bool, kwargs...) + partialresponse(model::T, i::Integer; bins=50, inflated::Bool, kwargs...) This method returns the partial response of applying the trained model to a simulated dataset where all variables *except* `i` are set to their mean value. @@ -45,14 +45,14 @@ The different arguments that can follow the variable position are All keyword arguments are passed to `predict`. """ -function partialresponse(model::T, i::Integer, args...; inflated::Bool=false, kwargs...) where {T <: AbstractSDM} - nx = SDeMo._make_partialresponse_data(model, i, args...) +function partialresponse(model::T, i::Integer, args...; bins::Integer=50, inflated::Bool=false, kwargs...) where {T <: AbstractSDM} + nx = SDeMo._make_partialresponse_data(model, i, bins) SDeMo._fill_partialresponse_data!(nx, model, i, inflated) return (nx[i,:], predict(model, nx; kwargs...)) end """ - partialresponse(model::T, i::Integer, j::Integer, s::Tuple=(50, 50); inflated::Bool, kwargs...) + partialresponse(model::T, i::Integer, j::Integer; bins=50, inflated::Bool, kwargs...) 
This method returns the partial response of applying the trained model to a simulated dataset where all variables *except* `i` and `j` are set to their mean @@ -64,9 +64,9 @@ and `j`, the size of which is given by the last argument `s` (defaults to 50 × All keyword arguments are passed to `predict`. """ -function partialresponse(model::T, i::Integer, j::Integer, s::Tuple=(50, 50); inflated::Bool=false, kwargs...) where {T <: AbstractSDM} - irange = LinRange(extrema(features(model, i))..., s[1]) - jrange = LinRange(extrema(features(model, j))..., s[2]) +function partialresponse(model::T, i::Integer, j::Integer, s::Tuple=(50, 50); bins::Integer=50, inflated::Bool=false, kwargs...) where {T <: AbstractSDM} + irange = LinRange(extrema(features(model, i))..., bins) + jrange = LinRange(extrema(features(model, j))..., bins) nx = zeros(eltype(features(model)), size(features(model), 1), length(irange)*length(jrange)) diff --git a/docs/src/howto/dataviz/models.jl b/docs/src/howto/dataviz/models.jl index 121d583f4a..1b3cb9c31f 100644 --- a/docs/src/howto/dataviz/models.jl +++ b/docs/src/howto/dataviz/models.jl @@ -9,16 +9,16 @@ CairoMakie.activate!(; type = "png", px_per_unit = 2) #hide # We will use the demonstration data from the `SDeMo` package: -model = SDM(RawData, NaiveBayes, SDeMo.__demodata()...) +model = SDM(RawData, Logistic, SDeMo.__demodata()...) # ## Plotting instances -scatter(model) +scatter(model; axis=(; aspect=DataAspect())) # This can be coupled with information about the model itself, to provide more # interesting visualisations: -scatter(model, color=labels(model)) +scatter(model; color=labels(model), axis=(; aspect=DataAspect())) # ## Model diagnostic plots @@ -26,8 +26,71 @@ scatter(model, color=labels(model)) train!(model) -# ::: info Coming soon -# -# These plots will be included in a future release of `SDeMo`- stay tuned! 
-# -# ::: \ No newline at end of file +# ### Ceteris paribus + +# The _ceteris paribus_ plot allows seeing the effect of all possible values of +# a feature on a specific instance. For example, this is how the prediction for +# instance 4 is affected by a change in the BIO1 variable (mean annual +# temperature). + +# All of these plots have a feature (values of) in the x axis, and the resulting +# prediction on the y axis. + +cpplot(model, 4, 1) + +# This plot can be drawn as a line rather than stairs: + +cpplot(model, 4, 1; stairs=false) + +# The line attributes can be changed: + +cpplot(model, 4, 1; stairs=false, linewidth=2, color=:red, linestyle=:dash) + +# The plot can also be presented by centering the x axis to the midpoint value: + +cpplot(model, 4, 1; center=:midpoint) + +# Or to the value of the feature for this specific instance: + +cpplot(model, 4, 1; center=:value) + +# CP plots can also be produced for two variables: + +cpplot(model, 4, 4, 6; colormap=:YlGnBu) + +# ### Individual conditional expectations + +# The ICE plot is the superposition of multiple CP plots. It uses the same +# arguments, but the instances are given as a range or collection. + +# To plot all the instances, we can use: + +iceplot(model, 1) + +# Because this creates a lot of overplotting, it is a good idea to tweak the +# transparency: + +iceplot(model, 1; alpha=0.2, center=:midpoint, stairs=false) + +# The list of instances to use can also be given as a collection: + +iceplot(model, findall(labels(model)), 1; alpha=0.2, stairs=false, color=:darkgreen, label="Presence") +iceplot!(model, findall(!, labels(model)), 1; alpha=0.2, stairs=false, color=:grey50, label="Absence") +axislegend(current_axis()) +current_figure() + +# ### Partial dependence + +# The partial dependence plot is the average of all CP plots. The instances to +# use in it are specified like in the ICE plots. 
+ +partialdependenceplot(model, 1) + +# We can also pass a `ribbon` function to draw a band around the line: + +import Statistics +partialdependenceplot(model, 1; ribbon=Statistics.std, stairs=false, background=:skyblue, color=:darkblue) + +# We can also specify to only run the plot for some instances: + +partialdependenceplot(model, findall(labels(model)), 1; ribbon=Statistics.std, stairs=false, background=:grey95, color=:orange) \ No newline at end of file