Skip to content

Commit a85081e

Browse files
Danielmatthieugomez
Daniel
andauthored
Fix weights on GPU (#41)
* simplify progressbar * Update FixedEffectSolverCPU.jl * Update progressbar.jl * Update FixedEffectSolverCPU.jl * correct weights * Update Project.toml * Fixing `nthreads` Fix `nthreads` before the loop in `update_weights!` when run on GPU * removed duplicate weights creation Co-authored-by: matthieugomez <[email protected]>
1 parent 28a015d commit a85081e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

Diff for: src/FixedEffectSolvers/FixedEffectSolverGPU.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,12 @@ function AbstractFixedEffectSolver{T}(fes::Vector{<:FixedEffect}, weights::Abstr
127127
h = FixedEffectCoefficients([cuzeros(T, fe.n) for fe in fes])
128128
hbar = FixedEffectCoefficients([cuzeros(T, fe.n) for fe in fes])
129129
tmp = zeros(T, length(weights))
130-
update_weights!(FixedEffectSolverGPU{T}(m, weights, b, r, x, v, h, hbar, tmp, fes), weights)
130+
update_weights!(FixedEffectSolverGPU{T}(m, cuzeros(T, length(weights)), b, r, x, v, h, hbar, tmp, fes), weights)
131131
end
132132

133133

134134
function update_weights!(feM::FixedEffectSolverGPU{T}, weights::AbstractWeights) where {T}
135-
weights = cu(T, collect(weights))
135+
weights = cu(T, weights)
136136
nthreads = feM.m.nthreads
137137
for (scale, fe) in zip(feM.m.scales, feM.m.fes)
138138
scale!(scale, fe.refs, fe.interaction, weights, nthreads)

0 commit comments

Comments
 (0)