Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 90 additions & 8 deletions code/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -1874,28 +1874,110 @@ mp_obj_t ndarray_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
#endif /* NDARRAY_HAS_UNARY_OPS */

#if NDARRAY_HAS_TRANSPOSE
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
#if ULAB_MAX_DIMS == 1
return self_in;
#endif
// TODO: check, what happens to the offset here, if we have a view
// We have to implement the T property separately, for the property can't take keyword arguments

#if ULAB_MAX_DIMS == 1
// isolating the one-dimensional case saves space, because the transpose is sort of meaningless
mp_obj_t ndarray_T(mp_obj_t self_in) {
return self_in;
}
#else
mp_obj_t ndarray_T(mp_obj_t self_in) {
// without argument, simply return a view with axes in reverse order
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
if(self->ndim == 1) {
return self_in;
}
size_t *shape = m_new(size_t, self->ndim);
int32_t *strides = m_new(int32_t, self->ndim);
for(uint8_t i=0; i < self->ndim; i++) {
for(uint8_t i = 0; i < self->ndim; i++) {
shape[ULAB_MAX_DIMS - 1 - i] = self->shape[ULAB_MAX_DIMS - self->ndim + i];
strides[ULAB_MAX_DIMS - 1 - i] = self->strides[ULAB_MAX_DIMS - self->ndim + i];
}
// TODO: I am not sure ndarray_new_view is OK here...
// should be deep copy...
ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
return MP_OBJ_FROM_PTR(ndarray);
}
#endif /* ULAB_MAX_DIMS == 1 */

MP_DEFINE_CONST_FUN_OBJ_1(ndarray_T_obj, ndarray_T);

# if ULAB_MAX_DIMS == 1
// again, nothing to do, if there is only one dimension, though, the arguments might still have to be parsed...
mp_obj_t ndarray_transpose(mp_obj_t self_in) {
return self_in;
}

MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
#else
mp_obj_t ndarray_transpose(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_axes, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

ndarray_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);

if(self->ndim == 1) {
return args[0].u_obj;
}

size_t *shape = m_new(size_t, self->ndim);
int32_t *strides = m_new(int32_t, self->ndim);
uint8_t *order = m_new(uint8_t, self->ndim);

mp_obj_t axes = args[1].u_obj;

if(axes == mp_const_none) {
// simply swap the order of the axes
for(uint8_t i = 0; i < self->ndim; i++) {
order[i] = self->ndim - 1 - i;
}
} else {
if(!mp_obj_is_type(axes, &mp_type_tuple)) {
mp_raise_TypeError(MP_ERROR_TEXT("keyword argument must be tuple of integers"));
}
// start with the straight array, and then swap only those specified in the argument
for(uint8_t i = 0; i < self->ndim; i++) {
order[i] = i;
}

mp_obj_tuple_t *axes_tuple = MP_OBJ_TO_PTR(axes);

if(axes_tuple->len > self->ndim) {
mp_raise_ValueError(MP_ERROR_TEXT("too many axes specified"));
}

for(uint8_t i = 0; i < axes_tuple->len; i++) {
int32_t ax = mp_obj_get_int(axes_tuple->items[i]);
if((ax >= self->ndim) || (ax < 0)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis index out of bounds"));
} else {
order[i] = (uint8_t)ax;
// TODO: check that no two identical numbers appear in the tuple
for(uint8_t j = 0; j < i; j++) {
if(order[i] == order[j]) {
mp_raise_ValueError(MP_ERROR_TEXT("repeated indices"));
}
}
}
}
}

uint8_t axis_offset = ULAB_MAX_DIMS - self->ndim;
for(uint8_t i = 0; i < self->ndim; i++) {
shape[axis_offset + i] = self->shape[axis_offset + order[i]];
strides[axis_offset + i] = self->strides[axis_offset + order[i]];
}

ndarray_obj_t *ndarray = ndarray_new_view(self, self->ndim, shape, strides, 0);
return MP_OBJ_FROM_PTR(ndarray);
}

MP_DEFINE_CONST_FUN_OBJ_KW(ndarray_transpose_obj, 1, ndarray_transpose);
#endif /* ULAB_MAX_DIMS == 1 */
#endif /* NDARRAY_HAS_TRANSPOSE */

#if ULAB_MAX_DIMS > 1
Expand Down
9 changes: 8 additions & 1 deletion code/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,16 @@ MP_DECLARE_CONST_FUN_OBJ_1(ndarray_tolist_obj);
#endif

#if NDARRAY_HAS_TRANSPOSE
mp_obj_t ndarray_T(mp_obj_t );
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_T_obj);
#if ULAB_MAX_DIMS == 1
mp_obj_t ndarray_transpose(mp_obj_t );
MP_DECLARE_CONST_FUN_OBJ_1(ndarray_transpose_obj);
#endif
#else
mp_obj_t ndarray_transpose(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(ndarray_transpose_obj);
#endif /* ULAB_MAX_DIMS == 1 */
#endif /* NDARRAY_HAS_TRANSPOSE */

#if ULAB_NUMPY_HAS_NDINFO
mp_obj_t ndarray_info(mp_obj_t );
Expand Down
2 changes: 1 addition & 1 deletion code/ndarray_properties.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void ndarray_properties_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
#endif
#if NDARRAY_HAS_TRANSPOSE
case MP_QSTR_T:
dest[0] = ndarray_transpose(self_in);
dest[0] = ndarray_T(self_in);
break;
#endif
#if ULAB_SUPPORTS_COMPLEX
Expand Down
Loading