@@ -333,11 +333,19 @@ function find_breaks(v::AbstractVector, qs::AbstractVector)
333333 return breaks
334334end
335335
336+ _quantile! (x:: AbstractArray , w:: Nothing , p:: AbstractVector ) =
337+ quantile! (x, p, sorted= true ) # Unweighted case (`w` is `nothing`): forwards to `quantile!` with `sorted=true`, i.e. `x` is assumed to be already sorted. NOTE(review): confirm every caller passes sorted data.
338+ # The method for `AbstractWeights` is defined in the StatsBase extension; the method below always throws an `ArgumentError` (presumably a fallback for any other weights vector).
339+ # Note that there is no in-place weighted quantile method in StatsBase.
340+ _quantile! (x:: AbstractArray , w:: AbstractVector , p:: AbstractVector ) =
341+ throw (ArgumentError (" `weights` must be an `AbstractWeights` vector from StatsBase.jl" ))
342+
336343"""
337344 cut(x::AbstractArray, ngroups::Integer;
338345 labels::Union{AbstractVector{<:AbstractString},Function},
339346 sigdigits::Integer=3,
340- allowempty::Bool=false)
347+ allowempty::Bool=false,
348+ weights::Union{AbstractWeights, Nothing}=nothing)
341349
342350Cut a numeric array into `ngroups` quantiles.
343351
@@ -369,19 +377,39 @@ quantiles.
369377 other than the last one are equal, generating empty intervals;
370378 when `true`, duplicate breaks are allowed and the intervals they generate are kept as
371379 unused levels (but duplicate labels are not allowed).
380+ * `weights::Union{AbstractWeights, Nothing}=nothing`: observation weights to pass to `quantile`.
372381"""
373382function cut (x:: AbstractArray , ngroups:: Integer ;
374383 labels:: Union{AbstractVector{<:SupportedTypes},Function,Nothing} = nothing ,
375384 sigdigits:: Integer = 3 ,
376- allowempty:: Bool = false )
385+ allowempty:: Bool = false ,
386+ weights:: Union{AbstractVector, Nothing} = nothing )
377387 ngroups >= 1 || throw (ArgumentError (" ngroups must be strictly positive (got $ngroups )" ))
378- sorted_x = eltype (x) >: Missing ? sort! (collect (skipmissing (x))) : sort (x)
388+ if weights === nothing
389+ sorted_x = eltype (x) >: Missing ? sort! (collect (skipmissing (x))) : sort (x)
390+ min_x, max_x = first (sorted_x), last (sorted_x)
391+ if (min_x isa Number && isnan (min_x)) ||
392+ (max_x isa Number && isnan (max_x))
393+ throw (ArgumentError (" NaN values are not allowed in input vector" ))
394+ end
395+ else
396+ if eltype (x) >: Missing
397+ nm_inds = findall (! ismissing, x)
398+ nm_x = view (x, nm_inds)
399+ # TODO : use a view once this is supported (JuliaStats/StatsBase.jl#723)
400+ nm_weights = weights[nm_inds]
401+ else
402+ nm_x = x
403+ nm_weights = weights
404+ end
405+ sorted_x = sort (nm_x)
406+ end
379407 min_x, max_x = first (sorted_x), last (sorted_x)
380408 if (min_x isa Number && isnan (min_x)) ||
381409 (max_x isa Number && isnan (max_x))
382410 throw (ArgumentError (" NaN values are not allowed in input vector" ))
383411 end
384- qs = quantile! (sorted_x, (1 : (ngroups- 1 ))/ ngroups, sorted = true )
412+ qs = _quantile! (nm_x, nm_weights, (1 : (ngroups- 1 ))/ ngroups)
385413 breaks = [min_x; find_breaks (sorted_x, qs); max_x]
386414 if ! allowempty && ! allunique (@view breaks[1 : end - 1 ])
387415 throw (ArgumentError (" cannot compute $ngroups quantiles due to " *
0 commit comments