diff --git a/code/ndarray.c b/code/ndarray.c index b297b3f1..361d664d 100644 --- a/code/ndarray.c +++ b/code/ndarray.c @@ -1874,28 +1874,110 @@ mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) { #endif /* NDARRAY_HAS_UNARY_OPS */ #if NDARRAY_HAS_TRANSPOSE -mp_obj_t ndarray_transpose(mp_obj_t self_in) { - #if ULAB_MAX_DIMS == 1 - return self_in; - #endif - // TODO: check, what happens to the offset here, if we have a view +// We have to implement the T property separately, for the property can't take keyword arguments + +#if ULAB_MAX_DIMS == 1 +// isolating the one-dimensional case saves space, because the transpose is sort of meaningless +mp_obj_t ndarray_T(mp_obj_t self_in) { + return self_in; +} +#else +mp_obj_t ndarray_T(mp_obj_t self_in) { + // without argument, simply return a view with axes in reverse order ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in); if(self->ndim == 1) { return self_in; } size_t *shape = m_new(size_t, self->ndim); int32_t *strides = m_new(int32_t, self->ndim); - for(uint8_t i=0; i < self->ndim; i++) { + for(uint8_t i = 0; i < self->ndim; i++) { shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i]; strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i]; } - // TODO: I am not sure ndarray_new_view is OK here... - // should be deep copy... ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0); return MP_OBJ_FROM_PTR(ndarray); } +#endif /* ULAB_MAX_DIMS == 1 */ + +MP_DEFINE_CONST_FUN_OBJ_1(ndarray_T_obj, ndarray_T); + +# if ULAB_MAX_DIMS == 1 +// again, nothing to do, if there is only one dimension, though, the arguments might still have to be parsed... +mp_obj_t ndarray_transpose(mp_obj_t self_in) { + return self_in; +} MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose); +#else +mp_obj_t ndarray_transpose(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + static const mp_arg_t allowed_args[] = { + { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, + { MP_QSTR_axes, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, + }; + + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); + + ndarray_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj); + + if(self->ndim == 1) { + return args[0].u_obj; + } + + size_t *shape = m_new(size_t, self->ndim); + int32_t *strides = m_new(int32_t, self->ndim); + uint8_t *order = m_new(uint8_t, self->ndim); + + mp_obj_t axes = args[1].u_obj; + + if(axes == mp_const_none) { + // simply swap the order of the axes + for(uint8_t i = 0; i < self->ndim; i++) { + order[i] = self->ndim - 1 - i; + } + } else { + if(!mp_obj_is_type(axes, &mp_type_tuple)) { + mp_raise_TypeError(MP_ERROR_TEXT("keyword argument must be tuple of integers")); + } + // start with the straight array, and then swap only those specified in the argument + for(uint8_t i = 0; i < self->ndim; i++) { + order[i] = i; + } + + mp_obj_tuple_t *axes_tuple = MP_OBJ_TO_PTR(axes); + + if(axes_tuple->len > self->ndim) { + mp_raise_ValueError(MP_ERROR_TEXT("too many axes specified")); + } + + for(uint8_t i = 0; i < axes_tuple->len; i++) { + int32_t ax = mp_obj_get_int(axes_tuple->items[i]); + if((ax >= self->ndim) || (ax < 0)) { + mp_raise_ValueError(MP_ERROR_TEXT("axis index out of bounds")); + } else { + order[i] = (uint8_t)ax; + // TODO: check that no two identical numbers appear in the tuple + for(uint8_t j = 0; j < i; j++) { + if(order[i] == order[j]) { + mp_raise_ValueError(MP_ERROR_TEXT("repeated indices")); + } + } + } + } + } + + uint8_t axis_offset = ULAB_MAX_DIMS - self->ndim; + for(uint8_t i = 0; i < self->ndim; i++) { + shape[axis_offset + i] = self->shape[axis_offset + order[i]]; + strides[axis_offset + i] = self->strides[axis_offset + order[i]]; + } + + ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0); + return MP_OBJ_FROM_PTR(ndarray); +} + +MP_DEFINE_CONST_FUN_OBJ_KW(ndarray_transpose_obj, 1, ndarray_transpose); +#endif /* ULAB_MAX_DIMS == 1 */ #endif /* NDARRAY_HAS_TRANSPOSE */ #if ULAB_MAX_DIMS > 1 diff --git a/code/ndarray.h b/code/ndarray.h index 8e81773d..af72b1ad 100644 --- a/code/ndarray.h +++ b/code/ndarray.h @@ -265,9 +265,16 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_tolist_obj); #endif #if NDARRAY_HAS_TRANSPOSE +mp_obj_t ndarray_T(mp_obj_t ); +MP_DECLARE_CONST_FUN_OBJ_1(ndarray_T_obj); +#if ULAB_MAX_DIMS == 1 mp_obj_t ndarray_transpose(mp_obj_t ); MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj); -#endif +#else +mp_obj_t ndarray_transpose(size_t , const mp_obj_t *, mp_map_t *); +MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_transpose_obj); +#endif /* ULAB_MAX_DIMS == 1 */ +#endif /* NDARRAY_HAS_TRANSPOSE */ #if ULAB_NUMPY_HAS_NDINFO mp_obj_t ndarray_info(mp_obj_t ); diff --git a/code/ndarray_properties.c b/code/ndarray_properties.c index 6c048bda..8200445a 100644 --- a/code/ndarray_properties.c +++ b/code/ndarray_properties.c @@ -64,7 +64,7 @@ void ndarray_properties_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { #endif #if NDARRAY_HAS_TRANSPOSE case MP_QSTR_T: - dest[0] = ndarray_transpose(self_in); + dest[0] = ndarray_T(self_in); break; #endif #if ULAB_SUPPORTS_COMPLEX