1
- from typing import Any , Optional , Union
1
+ from typing import Union
2
2
3
3
import numpy as np
4
4
from numpy .core .numeric import normalize_axis_index , normalize_axis_tuple
5
5
6
6
from .julia import jl
7
- from .levels import _Display , Dense , Element , Storage
7
+ from .levels import _Display , Dense , Element
8
8
from .typing import OrderType , JuliaObj , spmatrix , TupleOf3Arrays
9
9
10
10
@@ -57,12 +57,13 @@ class Tensor(_Display):
57
57
array([[0, 1, 2],
58
58
[3, 4, 5]])
59
59
"""
60
+
60
61
row_major = "C"
61
62
column_major = "F"
62
63
63
64
def __init__ (
64
65
self ,
65
- obj : Union [np .ndarray , spmatrix , Storage , JuliaObj ],
66
+ obj : Union [np .ndarray , spmatrix , JuliaObj ],
66
67
/ ,
67
68
* ,
68
69
fill_value : np .number = 0.0 ,
@@ -74,7 +75,9 @@ def __init__(
74
75
jl_data = self ._from_numpy (obj , fill_value = fill_value )
75
76
self ._obj = jl_data
76
77
elif isinstance (obj , Storage ): # from-storage constructor
77
- order = self .preprocess_order (obj .order , self .get_lvl_ndim (obj .levels_descr ._obj ))
78
+ order = self .preprocess_order (
79
+ obj .order , self .get_lvl_ndim (obj .levels_descr ._obj )
80
+ )
78
81
self ._obj = jl .swizzle (jl .Tensor (obj .levels_descr ._obj ), * order )
79
82
elif jl .isa (obj , jl .Finch .SwizzleArray ): # raw-Julia-object constructor
80
83
self ._obj = obj
@@ -143,25 +146,22 @@ def _order(self) -> tuple[int, ...]:
143
146
return jl .typeof (self ._obj ).parameters [1 ]
144
147
145
148
@classmethod
146
- def preprocess_order (
147
- cls , order : OrderType , ndim : int
148
- ) -> tuple [int , ...]:
149
+ def preprocess_order (cls , order : OrderType , ndim : int ) -> tuple [int , ...]:
149
150
if order == cls .column_major :
150
151
permutation = tuple (range (1 , ndim + 1 ))
151
152
elif order == cls .row_major or order is None :
152
153
permutation = tuple (range (1 , ndim + 1 )[::- 1 ])
153
154
elif isinstance (order , tuple ):
154
155
if min (order ) == 0 :
155
156
order = tuple (i + 1 for i in order )
156
- if (
157
- len (order ) == ndim and
158
- all ([i in order for i in range (1 , ndim + 1 )])
159
- ):
157
+ if len (order ) == ndim and all ([i in order for i in range (1 , ndim + 1 )]):
160
158
permutation = order
161
159
else :
162
160
raise ValueError (f"Custom order is not a permutation: { order } ." )
163
161
else :
164
- raise ValueError (f"order must be 'C', 'F' or a tuple, but is: { type (order )} ." )
162
+ raise ValueError (
163
+ f"order must be 'C', 'F' or a tuple, but is: { type (order )} ."
164
+ )
165
165
166
166
return permutation
167
167
@@ -214,11 +214,11 @@ def permute_dims(self, axes: tuple[int, ...]) -> "Tensor":
214
214
new_tensor = Tensor (new_obj )
215
215
return new_tensor
216
216
217
- def to_device (self , device : Storage ) -> "Tensor" :
217
+ def to_device (self , device ) -> "Tensor" :
218
218
return Tensor (self ._from_other_tensor (self , storage = device ))
219
219
220
220
@classmethod
221
- def _from_other_tensor (cls , tensor : "Tensor" , storage : Optional [ Storage ] ) -> JuliaObj :
221
+ def _from_other_tensor (cls , tensor : "Tensor" , storage ) -> JuliaObj :
222
222
order = cls .preprocess_order (storage .order , tensor .ndim )
223
223
return jl .swizzle (
224
224
jl .Tensor (storage .levels_descr ._obj , tensor ._obj .body ), * order
@@ -239,7 +239,10 @@ def _from_numpy(cls, arr: np.ndarray, fill_value: np.number) -> JuliaObj:
239
239
def _from_scipy_sparse (cls , x ) -> JuliaObj :
240
240
if x .format == "coo" :
241
241
return cls .construct_coo_jl_object (
242
- coords = (x .col , x .row ), data = x .data , shape = x .shape [::- 1 ], order = Tensor .row_major
242
+ coords = (x .col , x .row ),
243
+ data = x .data ,
244
+ shape = x .shape [::- 1 ],
245
+ order = Tensor .row_major ,
243
246
)
244
247
elif x .format == "csc" :
245
248
return cls .construct_csc_jl_object (
@@ -255,7 +258,9 @@ def _from_scipy_sparse(cls, x) -> JuliaObj:
255
258
raise ValueError (f"Unsupported SciPy format: { type (x )} " )
256
259
257
260
@classmethod
258
- def construct_coo_jl_object (cls , coords , data , shape , order , fill_value = 0.0 ) -> JuliaObj :
261
+ def construct_coo_jl_object (
262
+ cls , coords , data , shape , order , fill_value = 0.0
263
+ ) -> JuliaObj :
259
264
assert len (coords ) == 2
260
265
ndim = len (shape )
261
266
order = cls .preprocess_order (order , ndim )
@@ -264,12 +269,18 @@ def construct_coo_jl_object(cls, coords, data, shape, order, fill_value=0.0) ->
264
269
ptr = jl .Vector [jl .Int ]([1 , len (data ) + 1 ])
265
270
tbl = tuple (jl .PlusOneVector (arr ) for arr in coords )
266
271
267
- jl_data = jl .swizzle (jl .Tensor (jl .SparseCOO [ndim ](lvl , shape , ptr , tbl )), * order )
272
+ jl_data = jl .swizzle (
273
+ jl .Tensor (jl .SparseCOO [ndim ](lvl , shape , ptr , tbl )), * order
274
+ )
268
275
return jl_data
269
276
270
277
@classmethod
271
- def construct_coo (cls , coords , data , shape , order = row_major , fill_value = 0.0 ) -> "Tensor" :
272
- return Tensor (cls .construct_coo_jl_object (coords , data , shape , order , fill_value ))
278
+ def construct_coo (
279
+ cls , coords , data , shape , order = row_major , fill_value = 0.0
280
+ ) -> "Tensor" :
281
+ return Tensor (
282
+ cls .construct_coo_jl_object (coords , data , shape , order , fill_value )
283
+ )
273
284
274
285
@staticmethod
275
286
def _construct_compressed2d_jl_object (
@@ -288,22 +299,27 @@ def _construct_compressed2d_jl_object(
288
299
289
300
lvl = jl .Element (dtype (fill_value ), data )
290
301
jl_data = jl .swizzle (
291
- jl .Tensor (jl .Dense (jl .SparseList (lvl , shape [0 ], indptr , indices ), shape [1 ])), * order
302
+ jl .Tensor (
303
+ jl .Dense (jl .SparseList (lvl , shape [0 ], indptr , indices ), shape [1 ])
304
+ ),
305
+ * order ,
292
306
)
293
307
return jl_data
294
308
295
309
@classmethod
296
- def construct_csc_jl_object (cls , arg : TupleOf3Arrays , shape : tuple [ int , ...]) -> JuliaObj :
297
- return cls . _construct_compressed2d_jl_object (
298
- arg = arg , shape = shape , order = ( 1 , 2 )
299
- )
310
+ def construct_csc_jl_object (
311
+ cls , arg : TupleOf3Arrays , shape : tuple [ int , ...]
312
+ ) -> JuliaObj :
313
+ return cls . _construct_compressed2d_jl_object ( arg = arg , shape = shape , order = ( 1 , 2 ) )
300
314
301
315
@classmethod
302
316
def construct_csc (cls , arg : TupleOf3Arrays , shape : tuple [int , ...]) -> "Tensor" :
303
317
return Tensor (cls .construct_csc_jl_object (arg , shape ))
304
318
305
319
@classmethod
306
- def construct_csr_jl_object (cls , arg : TupleOf3Arrays , shape : tuple [int , ...]) -> JuliaObj :
320
+ def construct_csr_jl_object (
321
+ cls , arg : TupleOf3Arrays , shape : tuple [int , ...]
322
+ ) -> JuliaObj :
307
323
return cls ._construct_compressed2d_jl_object (
308
324
arg = arg , shape = shape [::- 1 ], order = (2 , 1 )
309
325
)
@@ -331,7 +347,9 @@ def construct_csf_jl_object(
331
347
for size , indices , indptr in zip (shape [:- 1 ], indices_list , indptr_list ):
332
348
lvl = jl .SparseList (lvl , size , indptr , indices )
333
349
334
- jl_data = jl .swizzle (jl .Tensor (jl .Dense (lvl , shape [- 1 ])), * range (1 , len (shape ) + 1 ))
350
+ jl_data = jl .swizzle (
351
+ jl .Tensor (jl .Dense (lvl , shape [- 1 ])), * range (1 , len (shape ) + 1 )
352
+ )
335
353
return jl_data
336
354
337
355
@classmethod
@@ -377,7 +395,9 @@ def _slice_plus_one(s: slice, size: int) -> range:
377
395
378
396
if s .stop is not None :
379
397
stop_offset = 2 if step < 0 else 0
380
- stop = normalize_axis_index (s .stop , size ) + stop_offset if s .stop < size else size
398
+ stop = (
399
+ normalize_axis_index (s .stop , size ) + stop_offset if s .stop < size else size
400
+ )
381
401
else :
382
402
stop = stop_default
383
403
@@ -429,6 +449,7 @@ def _expand_ellipsis(key: tuple, shape: tuple[int, ...]) -> tuple:
429
449
key = new_key
430
450
return key
431
451
452
+
432
453
def _add_missing_dims (key : tuple , shape : tuple [int , ...]) -> tuple :
433
454
for i in range (len (key ), len (shape )):
434
455
key = key + (jl .range (start = 1 , stop = shape [i ]),)
0 commit comments