Skip to content

Commit b983962

Browse files
committed
allow referring to a column multiple times in some functions + add op option to closejoin + clean up closejoin
allow referring to a column multiple times in `byrow`, `transpose`, and joins
1 parent bad1006 commit b983962

File tree

9 files changed

+302
-253
lines changed

9 files changed

+302
-253
lines changed

src/byrow/byrow.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function expand_Base_Fix(f, f2)
2929
end
3030

3131
function byrow(ds::AbstractDataset, ::typeof(any), cols::MultiColumnIndex = :; by = isequal(true), threads = nrow(ds) > __NCORES*10, mapformats = false)
32-
colsidx = index(ds)[cols]
32+
colsidx = multiple_getindex(index(ds), cols)
3333
if by isa AbstractVector
3434
if mapformats
3535
by = map((x,y)->expand_Base_Fix(x, getformat(ds, y)), by, colsidx)
@@ -46,7 +46,7 @@ end
4646
byrow(ds::AbstractDataset, ::typeof(any), col::ColumnIndex; by = isequal(true), threads = nrow(ds) > __NCORES*10, mapformats = false) = byrow(ds, any, [col]; by = by, threads = threads, mapformats = mapformats)
4747

4848
function byrow(ds::AbstractDataset, ::typeof(all), cols::MultiColumnIndex = :; by = isequal(true), threads = nrow(ds) > __NCORES*10, mapformats = false)
49-
colsidx = index(ds)[cols]
49+
colsidx = multiple_getindex(index(ds), cols)
5050
if by isa AbstractVector
5151
if mapformats
5252
by = map((x,y)->expand_Base_Fix(x, getformat(ds, y)), by, colsidx)
@@ -146,7 +146,7 @@ byrow(ds::AbstractDataset, ::typeof(hash), col::ColumnIndex; by = identity, thre
146146
byrow(ds::AbstractDataset, ::typeof(mapreduce), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); op = .+, f = identity, init = missings(mapreduce(eltype, promote_type, view(_columns(ds),index(ds)[cols])), nrow(ds)), kwargs...) = mapreduce(f, op, eachcol(ds[!, cols]), init = init; kwargs...)
147147

148148
function byrow(ds::AbstractDataset, f::Function, cols::MultiColumnIndex; threads = nrow(ds)>1000)
149-
colsidx = index(ds)[cols]
149+
colsidx = multiple_getindex(index(ds), cols)
150150
length(colsidx) == 1 && return byrow(ds, f, colsidx[1]; threads = threads)
151151
threads ? hp_row_generic(ds, f, cols) : row_generic(ds, f, cols)
152152
end

src/byrow/hp_row_functions.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function hp_row_sort(ds::AbstractDataset, cols = names(ds, Union{Missing, Number
2626
end
2727

2828
function hp_row_generic(ds::AbstractDataset, f::Function, cols::MultiColumnIndex)
29-
colsidx = index(ds)[cols]
29+
colsidx = multiple_getindex(index(ds), cols)
3030
if length(colsidx) == 2
3131
try
3232
allowmissing(f.(_columns(ds)[colsidx[1]], _columns(ds)[colsidx[2]]))

src/byrow/row_functions.jl

+21-21
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Base.@propagate_inbounds function _op_for_sum!(x, y, f, lo, hi)
3232
end
3333

3434
function row_sum(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
35-
colsidx = index(ds)[cols]
35+
colsidx = multiple_getindex(index(ds), cols)
3636
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
3737
T = Core.Compiler.return_type(f, (CT,))
3838
init0 = missings(T, nrow(ds))
@@ -62,7 +62,7 @@ Base.@propagate_inbounds function _op_for_prod!(x, y, f, lo, hi)
6262
end
6363

6464
function row_prod(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
65-
colsidx = index(ds)[cols]
65+
colsidx = multiple_getindex(index(ds), cols)
6666
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
6767
T = Core.Compiler.return_type(f, (CT,))
6868
init0 = missings(T, nrow(ds))
@@ -91,7 +91,7 @@ end
9191

9292

9393
function row_count(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
94-
colsidx = index(ds)[cols]
94+
colsidx = multiple_getindex(index(ds), cols)
9595
init0 = zeros(Int32, size(ds,1))
9696

9797
if threads
@@ -118,7 +118,7 @@ Base.@propagate_inbounds function op_for_any!(x, y, f, lo, hi)
118118
end
119119

120120
function row_any(ds::AbstractDataset, f::Union{AbstractVector{<:Function}, Function}, cols = :; threads = true)
121-
colsidx = index(ds)[cols]
121+
colsidx = multiple_getindex(index(ds), cols)
122122
init0 = zeros(Bool, size(ds,1))
123123

124124
multi_f = false
@@ -160,7 +160,7 @@ Base.@propagate_inbounds function op_for_all!(x, y, f, lo, hi)
160160
end
161161

162162
function row_all(ds::AbstractDataset, f::Union{AbstractVector{<:Function}, Function}, cols = :; threads = true)
163-
colsidx = index(ds)[cols]
163+
colsidx = multiple_getindex(index(ds), cols)
164164
init0 = ones(Bool, size(ds,1))
165165

166166
multi_f = false
@@ -204,7 +204,7 @@ end
204204

205205

206206
function row_isequal(ds::AbstractDataset, cols = :; by::Union{AbstractVector, DatasetColumn, SubDatasetColumn, ColumnIndex, Nothing} = nothing, threads = true)
207-
colsidx = index(ds)[cols]
207+
colsidx = multiple_getindex(index(ds), cols)
208208
if !(by isa ColumnIndex) && by !== nothing
209209
@assert length(by) == nrow(ds) "to compare values of selected columns in each row, the length of the passed vector and the number of rows must match"
210210
end
@@ -250,7 +250,7 @@ function row_isless(ds::AbstractDataset, cols, colselector::Union{AbstractVector
250250
if !(colselector isa ColumnIndex)
251251
@assert length(colselector) == nrow(ds) "to compare values of selected columns in each row, the length of the passed vector and the number of rows must match"
252252
end
253-
colsidx = index(ds)[cols]
253+
colsidx = multiple_getindex(index(ds), cols)
254254
if colselector isa SubDatasetColumn || colselector isa DatasetColumn
255255
colselector = __!(colselector)
256256
elseif colselector isa ColumnIndex
@@ -314,7 +314,7 @@ function row_findfirst(ds::AbstractDataset, f, cols = names(ds, Union{Missing, N
314314
elseif item isa ColumnIndex
315315
item = _columns(ds)[index(ds)[item]]
316316
end
317-
colsidx = index(ds)[cols]
317+
colsidx = multiple_getindex(index(ds), cols)
318318

319319
colnames_pa = allowmissing(PooledArray(_names(ds)[colsidx]))
320320
push!(colnames_pa, missing)
@@ -353,7 +353,7 @@ function row_findlast(ds::AbstractDataset, f, cols = names(ds, Union{Missing, Nu
353353
elseif item isa ColumnIndex
354354
item = _columns(ds)[index(ds)[item]]
355355
end
356-
colsidx = index(ds)[cols]
356+
colsidx = multiple_getindex(index(ds), cols)
357357
colnames_pa = allowmissing(PooledArray(_names(ds)[colsidx]))
358358
push!(colnames_pa, missing)
359359
missref = get(colnames_pa.invpool, missing, 0)
@@ -429,7 +429,7 @@ function row_select(ds::AbstractDataset, cols, colselector::Union{AbstractVector
429429
if !(colselector isa ColumnIndex)
430430
@assert length(colselector) == nrow(ds) "to pick values of selected columns in each row, the length of the column names and the number of rows must match, i.e. the length of the vector passed as `by` must be $(nrow(ds))."
431431
end
432-
colsidx = index(ds)[cols]
432+
colsidx = multiple_getindex(index(ds), cols)
433433
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
434434
if colselector isa SubDatasetColumn || colselector isa DatasetColumn
435435
colselector = __!(colselector)
@@ -481,7 +481,7 @@ function row_fill!(ds::AbstractDataset, cols, val::Union{AbstractVector, Dataset
481481
if !(val isa ColumnIndex)
482482
@assert length(val) == nrow(ds) "to fill values in each row, the length of passed values and the number of rows must match."
483483
end
484-
colsidx = index(ds)[cols]
484+
colsidx = multiple_getindex(index(ds), cols)
485485
if val isa SubDatasetColumn || val isa DatasetColumn
486486
val = __!(val)
487487
end
@@ -527,7 +527,7 @@ end
527527

528528

529529
function row_coalesce(ds::AbstractDataset, cols = names(ds, Union{Missing, Number}); threads = true)
530-
colsidx = index(ds)[cols]
530+
colsidx = multiple_getindex(index(ds), cols)
531531
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
532532

533533
init0 = fill!(Vector{Union{Missing, CT}}(undef, size(ds,1)), missing)
@@ -566,7 +566,7 @@ Base.@propagate_inbounds function _op_for_max!(x, y, f, lo, hi)
566566
end
567567

568568
function row_minimum(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
569-
colsidx = index(ds)[cols]
569+
colsidx = multiple_getindex(index(ds), cols)
570570
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
571571
T = Core.Compiler.return_type(f, (CT,))
572572
init0 = missings(T, nrow(ds))
@@ -586,7 +586,7 @@ end
586586
row_minimum(ds::AbstractDataset, cols = names(ds, Union{Missing, Number}); threads = true) = row_minimum(ds, identity, cols; threads = threads)
587587

588588
function row_maximum(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
589-
colsidx = index(ds)[cols]
589+
colsidx = multiple_getindex(index(ds), cols)
590590
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
591591
T = Core.Compiler.return_type(f, (CT,))
592592
init0 = missings(T, nrow(ds))
@@ -616,7 +616,7 @@ Base.@propagate_inbounds function _op_for_argminmax!(x, y, f, vals, idx, missref
616616
end
617617

618618
function row_argmin(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
619-
colsidx = index(ds)[cols]
619+
colsidx = multiple_getindex(index(ds), cols)
620620
minvals = row_minimum(ds, f, cols)
621621
colnames_pa = allowmissing(PooledArray(_names(ds)[colsidx]))
622622
push!(colnames_pa, missing)
@@ -642,7 +642,7 @@ end
642642
row_argmin(ds::AbstractDataset, cols = names(ds, Union{Missing, Number}); threads = true) = row_argmin(ds, identity, cols, threads = threads)
643643

644644
function row_argmax(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
645-
colsidx = index(ds)[cols]
645+
colsidx = multiple_getindex(index(ds), cols)
646646
maxvals = row_maximum(ds, f, cols)
647647
colnames_pa = allowmissing(PooledArray(_names(ds)[colsidx]))
648648
push!(colnames_pa, missing)
@@ -689,7 +689,7 @@ end
689689
# TODO needs type stability
690690
# TODO need abs2 for calculations
691691
function row_var(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); dof = true, threads = true)
692-
colsidx = index(ds)[cols]
692+
colsidx = multiple_getindex(index(ds), cols)
693693
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))
694694
T = Core.Compiler.return_type(f, (CT,))
695695
_sq_(x) = x^2
@@ -1026,7 +1026,7 @@ Base.@propagate_inbounds function _op_for_issorted_rev!(x, y, res, lt, lo, hi)
10261026
end
10271027

10281028
function row_issorted(ds::AbstractDataset, cols; rev = false, lt = isless, threads = true)
1029-
colsidx = index(ds)[cols]
1029+
colsidx = multiple_getindex(index(ds), cols)
10301030
init0 = ones(Bool, nrow(ds))
10311031

10321032
if threads
@@ -1068,7 +1068,7 @@ function _fill_dict_and_add!(init0, dict, prehashed, n, p)
10681068
end
10691069

10701070
function row_nunique(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); count_missing = true)
1071-
colsidx = index(ds)[cols]
1071+
colsidx = multiple_getindex(index(ds), cols)
10721072
prehashed = Matrix{_Prehashed}(undef, size(ds,1), length(colsidx))
10731073
allcols = view(_columns(ds),colsidx)
10741074

@@ -1095,7 +1095,7 @@ Base.@propagate_inbounds function _op_for_hash!(x, y, f, lo, hi)
10951095
end
10961096

10971097
function row_hash(ds::AbstractDataset, f::Function, cols = :; threads = true)
1098-
colsidx = index(ds)[cols]
1098+
colsidx = multiple_getindex(index(ds), cols)
10991099
init0 = zeros(UInt64, nrow(ds))
11001100

11011101
if threads
@@ -1124,7 +1124,7 @@ function _fill_matrix!(inmat, all_data, rows, cols)
11241124
end
11251125

11261126
function row_generic(ds::AbstractDataset, f::Function, cols::MultiColumnIndex)
1127-
colsidx = index(ds)[cols]
1127+
colsidx = multiple_getindex(index(ds), cols)
11281128
if length(colsidx) == 2
11291129
try
11301130
allowmissing(f.(_columns(ds)[colsidx[1]], _columns(ds)[colsidx[2]]))

src/dataset/transpose.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ Base.transpose(::Dataset, cols; [id , renamecolid , renamerowid , variable_name,
143143

144144
function ds_transpose(ds, cols::Union{Tuple, MultiColumnIndex}; id = nothing, renamecolid = nothing, renamerowid = _default_renamerowid_function, variable_name = "_variables_", threads = true, mapformats = true)
145145
if cols isa Tuple
146-
tcols = [cols[j] isa ColumnIndex ? index(ds)[[cols[j]]] : index(ds)[cols[j]] for j in 1:length(cols)]
146+
tcols = [cols[j] isa ColumnIndex ? index(ds)[[cols[j]]] : multiple_getindex(index(ds), cols[j]) for j in 1:length(cols)]
147147
else
148-
tcols = [index(ds)[cols]]
148+
tcols = [multiple_getindex(index(ds), cols)]
149149
end
150150
max_num_col = maximum(length, tcols)
151151
if variable_name isa AbstractString || variable_name isa Symbol || variable_name === nothing
@@ -157,7 +157,7 @@ function ds_transpose(ds, cols::Union{Tuple, MultiColumnIndex}; id = nothing, re
157157
throw(ArgumentError("`variable_name` must be a string, symbol, nothing, or a vector of them"))
158158
end
159159
if id !== nothing
160-
ididx = index(ds)[id]
160+
ididx = multiple_getindex(index(ds), id)
161161

162162
if renamecolid === nothing
163163
renamecolid = _default_renamecolid_function_withid
@@ -422,9 +422,9 @@ end
422422

423423
function ds_transpose(ds::Union{Dataset, GroupBy, GatherBy}, cols::Union{Tuple, MultiColumnIndex}, gcols::MultiColumnIndex; id = nothing, renamecolid = nothing, renamerowid = _default_renamerowid_function, variable_name = "_variables_", default_fill = missing, threads = true, mapformats = true)
424424
if cols isa Tuple
425-
tcols = [cols[j] isa ColumnIndex ? index(ds)[[cols[j]]] : index(ds)[cols[j]] for j in 1:length(cols)]
425+
tcols = [cols[j] isa ColumnIndex ? index(ds)[[cols[j]]] : multiple_getindex(index(ds), cols[j]) for j in 1:length(cols)]
426426
else
427-
tcols = [index(ds)[cols]]
427+
tcols = [multiple_getindex(index(ds), cols)]
428428
end
429429
max_num_col = maximum(length, tcols)
430430
gcolsidx = gcols
@@ -461,7 +461,7 @@ function ds_transpose(ds::Union{Dataset, GroupBy, GatherBy}, cols::Union{Tuple,
461461
end
462462
end
463463
if var_name[j] !== nothing
464-
_repeat_row_names = allowmissing(PooledArray(renamerowid.(names(ds, sel_cols))))
464+
_repeat_row_names = allowmissing(PooledArray(renamerowid.(names(ds)[sel_cols])))
465465
_extend_repeat_row_names!(_repeat_row_names, max_num_col)
466466
_repeat_row_names.refs = repeat(_repeat_row_names.refs, nrow(ds))
467467
new_var_label = Symbol(var_name[j])
@@ -471,7 +471,7 @@ function ds_transpose(ds::Union{Dataset, GroupBy, GatherBy}, cols::Union{Tuple,
471471
_fill_col_val!(res, ECol, length(sel_cols), max_num_col, nrow(ds), _get_perms(ds), threads)
472472
local new_col_id
473473
try
474-
new_col_id = Symbol(renamecolid(1, names(ds, sel_cols)))
474+
new_col_id = Symbol(renamecolid(1, names(ds)[sel_cols]))
475475
catch e
476476
if (e isa MethodError)
477477
new_col_id = Symbol(renamecolid(1))
@@ -485,7 +485,7 @@ function ds_transpose(ds::Union{Dataset, GroupBy, GatherBy}, cols::Union{Tuple,
485485

486486
end
487487
if id !== nothing
488-
ididx = index(ds)[id]
488+
ididx = multiple_getindex(index(ds), id)
489489
if renamecolid === nothing
490490
renamecolid = _default_renamecolid_function_withid
491491
end

0 commit comments

Comments
 (0)