Skip to content

Commit 113c2b5

Browse files
committed
fix: setindex with scalars
1 parent e1b0d83 commit 113c2b5

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/ConcreteRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
154154
end
155155

156156
function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
157-
Base.setindex!(a, v, args...)
157+
setindex!(a, v, args...)
158158
return nothing
159159
end
160160

src/TracedRArray.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
9999
end
100100

101101
function Base.setindex!(
102-
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon},N}
102+
a::TracedRArray{T,N}, v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int},N}
103103
) where {T,N}
104+
indices = map(indices) do i
105+
i isa Int ? (i:i) : (i isa Colon ? (1:size(a, i)) : i)
106+
end
107+
v = broadcast_to_size(v, length.(indices))
108+
v = promote_to(TracedRArray{T,N}, v)
104109
indices = [
105110
(promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for
106111
i in indices
107112
]
108-
v = promote_to(TracedRArray{T,N}, v)
109113
res = MLIR.IR.result(
110114
MLIR.Dialects.stablehlo.dynamic_update_slice(a.mlir_data, v.mlir_data, indices), 1
111115
)

0 commit comments

Comments
 (0)