Hi @kylebd99,
I think this is the script that reproduces the Galley MTTKRP performance regression I mentioned for the 1.2.0 release:
```julia
using Finch

# Each entry is (I, J, K, L, DENSITY)
config = (
    (100, 25, 100, 10, 0.001),
    (100, 25, 100, 100, 0.001),
    (1000, 25, 100, 100, 0.001),
    (1000, 25, 1000, 100, 0.001),
)

for (I, J, K, L, DENSITY) in config
    B_shape = (I, K, L);
    # Sparse 3-tensor B (I x K x L) and dense factor matrices D (L x J) and C (K x J)
    B_tensor = fsprand(Float64, I, K, L, DENSITY);
    D_tensor = rand(L, J);
    C_tensor = rand(K, J);
    B_lazy = lazy(swizzle(B_tensor, 1, 2, 3));
    D_lazy = lazy(swizzle(Tensor(D_tensor), 1, 2));
    C_lazy = lazy(swizzle(Tensor(C_tensor), 1, 2));

    # MTTKRP: sums B[i, k, l] * C[k, j] * D[l, j] over k and l (dims 2 and 3)
    plan = sum(
        permutedims(
            broadcast(*,
                permutedims(
                    permutedims(
                        broadcast(*,
                            permutedims(B_lazy[:, :, :, nothing], (4, 3, 2, 1)),
                            permutedims(D_lazy[nothing, nothing, :, :], (4, 3, 2, 1)),
                        ),
                        (4, 3, 2, 1),
                    ),
                    (4, 3, 2, 1),
                ),
                permutedims(C_lazy[nothing, :, nothing, :], (4, 3, 2, 1)),
            ),
            (4, 3, 2, 1),
        ),
        dims=(2, 3),
    );

    # Default scheduler: the first call warms up (compilation), the second is timed
    scheduler = Finch.default_scheduler();
    result = compute(plan, ctx=scheduler);
    t0 = time()
    result = compute(plan, ctx=scheduler);
    t1 = time()
    print("Default - Elapsed: ", t1 - t0, "\n")

    # Galley scheduler: same warm-up-then-time pattern
    scheduler = Finch.galley_scheduler();
    result = compute(plan, ctx=scheduler, tag=sum(B_shape));
    t0 = time()
    result = compute(plan, ctx=scheduler, tag=sum(B_shape));
    t1 = time()
    print("Galley - Elapsed: ", t1 - t0, "\n\n")
end
```
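As a correctness check to go alongside the timings, the nested `permutedims`/`broadcast` plan appears to reduce to a standard MTTKRP, `A[i, j] = sum over k, l of B[i, k, l] * C[k, j] * D[l, j]`. A minimal dense reference in plain Julia (no Finch dependency) could look like the sketch below. The name `mttkrp_reference` is hypothetical, and comparing it against `result` assumes you can materialize `B_tensor` as a dense array and drop the singleton dimensions 2 and 3 of the computed output.

```julia
# Hypothetical dense reference for the MTTKRP the lazy plan appears to encode:
#   A[i, j] = sum over k, l of B[i, k, l] * C[k, j] * D[l, j]
function mttkrp_reference(B::AbstractArray{<:Number,3}, C::AbstractMatrix, D::AbstractMatrix)
    I, K, L = size(B)
    J = size(C, 2)
    @assert size(C, 1) == K && size(D) == (L, J) "shape mismatch between B, C and D"
    A = zeros(promote_type(eltype(B), eltype(C), eltype(D)), I, J)
    for j in 1:J, l in 1:L, k in 1:K
        w = C[k, j] * D[l, j]      # Khatri-Rao factor for (k, l, j)
        for i in 1:I               # innermost loop runs down a column (column-major)
            A[i, j] += B[i, k, l] * w
        end
    end
    return A
end
```

The loop nest is deliberately naive so it is easy to audit; it is only meant as a ground truth for small configurations, not as a performance baseline.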
And here are the timings for the two versions:
| (I, J, K, L, DENSITY) | v1.1.0 Default (s) | v1.1.0 Galley (s) | v1.2.0 Default (s) | v1.2.0 Galley (s) |
|---|---|---|---|---|
| (100, 25, 100, 10, 0.001) | 0.010934114 | 0.012537002 | 0.011469125 | 0.110867977 |
| (100, 25, 100, 100, 0.001) | 0.000733137 | 0.000170946 | 0.000752925 | 0.010062932 |
| (1000, 25, 100, 100, 0.001) | 0.008889913 | 0.000409841 | 0.007393121 | 0.011270999 |
| (1000, 25, 1000, 100, 0.001) | 0.056946039 | 0.002696037 | 0.056483030 | 0.013918876 |
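Each number above comes from a single timed call after one warm-up call. To make version-to-version comparisons less sensitive to run-to-run noise, one option is to take the minimum over several runs. The helper below is a hypothetical sketch, not part of Finch; BenchmarkTools.jl's `@belapsed` serves the same purpose if adding a dependency is acceptable.

```julia
# Hypothetical helper: call `f` once to warm up (JIT compilation), then time it
# `n` more times and return the minimum elapsed time in seconds.
function min_elapsed(f; n::Integer = 5)
    n >= 1 || throw(ArgumentError("n must be at least 1"))
    f()                      # warm-up call, not timed
    best = Inf
    for _ in 1:n
        t = @elapsed f()     # wall-clock seconds for one call
        best = min(best, t)
    end
    return best
end
```

For example, the Galley measurement in the loop could become `min_elapsed(() -> compute(plan, ctx=scheduler, tag=sum(B_shape)))`. Taking the minimum rather than the mean filters out one-off pauses such as garbage collection.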