Skip to content

Commit 4d02fd3

Browse files
suda-yugammcky
andauthored
Replace np.sum(a * b) with a @ b (#214)
Co-authored-by: Matt McKay <[email protected]>
1 parent 677e00b commit 4d02fd3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lectures/calvo_machine_learn.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def compute_θ(μ, α=1):
567567
568568
# Compute the weighted sums for all t
569569
weighted_sums = jnp.array(
570-
[jnp.sum(λ_powers[:T-t] * μ[t:T]) for t in range(T)])
570+
[(λ_powers[:T-t] @ μ[t:T]) for t in range(T)])
571571
572572
# Compute θ values except for the last element
573573
θ = (1 - λ) * weighted_sums + λ**(T - jnp.arange(T)) * μbar
@@ -595,7 +595,7 @@ def compute_V(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
595595
t = np.arange(T)
596596
597597
# Compute sum except for the last element
598-
V_sum = np.sum(β**t * (h0 + h1 * θ[:T] + h2 * θ[:T]**2 - 0.5 * c * μ[:T]**2))
598+
V_sum = (β**t) @ (h0 + h1 * θ[:T] + h2 * θ[:T]**2 - 0.5 * c * μ[:T]**2)
599599
600600
# Compute the final term
601601
V_final = (β**T / (1 - β)) * (h0 + h1 * μ[-1] + h2 * μ[-1]**2 - 0.5 * c * μ[-1]**2)
@@ -931,7 +931,7 @@ def compute_J(μ, β, c, α=1, u0=1, u1=0.5, u2=3):
931931
(β**T/(1-β))])
932932
933933
θ = B @ μ
934-
βθ_sum = jnp.sum((β_vec * h1) * θ)
934+
βθ_sum = (β_vec * h1) @ θ
935935
βθ_square_sum = β_vec * h2 * θ.T @ θ
936936
βμ_square_sum = 0.5 * c * β_vec * μ.T @ μ
937937

0 commit comments

Comments
 (0)