@@ -28,7 +28,7 @@ def einsum(equation, *operands):
28
28
Uses uncased letters to specify the dimension of the operands and result. The input
29
29
equation is on the left hand before `->` while the output equation is on the right side.
30
30
Einsum can infer the result shape so that the `->` and the result label letters can be omitted.
31
- Operands in the input equation are splited by commas (','), e.g. 'abc,cde' describes two 3D
31
+ Operands in the input equation are splitted by commas (','), e.g. 'abc,cde' describes two 3D
32
32
operands. The dimensions labeled with same letter should be same or be 1. Ellipsis ('...') can
33
33
be used to specify the broadcast dimensions.
34
34
@@ -129,14 +129,14 @@ def _mul_sum(left, right, sum_dims):
129
129
is_right_summed_dim = right .shape [i ] > 1
130
130
if i in sum_dims_set :
131
131
if is_left_summed_dim and is_right_summed_dim :
132
- assert left .shape [i ] == right .shape [i ], "Non-brocast dim should be equal."
132
+ assert left .shape [i ] == right .shape [i ], "Non-broadcast dim should be equal."
133
133
summed_size *= left .shape [i ]
134
134
elif is_left_summed_dim :
135
135
left = left .sum (axis = i , keepdim = True )
136
136
elif is_right_summed_dim :
137
137
right = right .sum (axis = i , keepdim = True )
138
138
elif is_left_summed_dim and is_right_summed_dim :
139
- assert left .shape [i ] == right .shape [i ], "Non-brocast dim should be equal."
139
+ assert left .shape [i ] == right .shape [i ], "Non-broadcast dim should be equal."
140
140
batch_dims .append (i )
141
141
batch_size *= left .shape [i ]
142
142
elif is_left_summed_dim :
@@ -204,7 +204,7 @@ def _mul_sum(left, right, sum_dims):
204
204
for ch in term :
205
205
if ch == "." :
206
206
ell_char_count += 1
207
- assert ell_char_count <= 3 , "The '.' should only exist in one ellispis '...' in term {}" .format (term )
207
+ assert ell_char_count <= 3 , "The '.' should only exist in one ellipsis '...' in term {}" .format (term )
208
208
if ell_char_count == 3 :
209
209
if num_ell_idxes == - 1 :
210
210
num_ell_idxes = curr_num_ell_idxes
@@ -213,7 +213,7 @@ def _mul_sum(left, right, sum_dims):
213
213
else :
214
214
assert (
215
215
curr_num_ell_idxes == num_ell_idxes
216
- ), "Ellispis in all terms should represent same dimensions ({})." .format (num_ell_idxes )
216
+ ), "Ellipsis in all terms should represent same dimensions ({})." .format (num_ell_idxes )
217
217
218
218
for j in range (num_ell_idxes ):
219
219
curr_operand_idxes .append (j + first_ell_idx )
@@ -247,11 +247,11 @@ def _mul_sum(left, right, sum_dims):
247
247
for ch in output_eqn :
248
248
if ch == "." :
249
249
ell_char_count += 1
250
- assert ell_char_count <= 3 , "The '.' should only exist in one ellispis '...' in term {}" .format (
250
+ assert ell_char_count <= 3 , "The '.' should only exist in one ellipsis '...' in term {}" .format (
251
251
output_eqn
252
252
)
253
253
if ell_char_count == 3 :
254
- assert num_ell_idxes > - 1 , "Input equation '{}' don't have ellispis ." .format (input_eqn )
254
+ assert num_ell_idxes > - 1 , "Input equation '{}' don't have ellipsis ." .format (input_eqn )
255
255
for j in range (num_ell_idxes ):
256
256
idxes_to_output_dims [first_ell_idx + j ] = num_output_dims
257
257
num_output_dims += 1
0 commit comments