Skip to content

Commit 73356c6

Browse files
BUG: change the type of @ result from MatrixVariable to MatrixExpr (#1059)
* Add __matmul__ support to MatrixExpr Implements the __matmul__ method for MatrixExpr, enabling matrix multiplication using the @ operator and ensuring the result is returned as a MatrixExpr instance. * Add tests for matrix matmul return types Adds tests to verify that matrix multiplication returns MatrixExpr instead of MatrixVariable for various input shapes. * Use type() checks instead of isinstance in matmul tests Replaces isinstance checks with type() comparisons for MatrixExpr in matrix matmul return type tests to ensure exact type matching. * Fix matrix variable shapes in matmul test Corrects the shapes of matrix variables in test_matrix_matmul_return_type to ensure proper 2D matrix multiplication and type assertion. * Update CHANGELOG.md --------- Co-authored-by: João Dionísio <[email protected]>
1 parent 13d4a68 commit 73356c6

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
### Fixed
2525
- Raised an error when an expression is used when a variable is required
2626
- Fixed some compile warnings
27+
- Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr.
2728
### Changed
2829
- MatrixExpr.sum() now supports axis arguments and can return either a scalar or MatrixExpr, depending on the result dimensions.
2930
- AddMatrixCons() also accepts ExprCons.

src/pyscipopt/matrix.pxi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ class MatrixExpr(np.ndarray):
9898

9999
def __rsub__(self, other):
100100
return super().__rsub__(other).view(MatrixExpr)
101-
101+
102+
def __matmul__(self, other):
103+
return super().__matmul__(other).view(MatrixExpr)
104+
102105
class MatrixGenExpr(MatrixExpr):
103106
pass
104107

tests/test_matrix_variable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,20 @@ def test_matrix_cons_indicator():
392392
assert m.getVal(is_equal).sum() == 2
393393
assert (m.getVal(x) == m.getVal(y)).all().all()
394394
assert (m.getVal(x) == np.array([[5, 5, 5], [5, 5, 5]])).all().all()
395+
396+
397+
def test_matrix_matmul_return_type():
398+
# test #1058, require returning type is MatrixExpr not MatrixVariable
399+
m = Model()
400+
401+
# test 1D @ 1D → 0D
402+
x = m.addMatrixVar(3)
403+
assert type(x @ x) is MatrixExpr
404+
405+
# test 1D @ 1D → 2D
406+
assert type(x[:, None] @ x[None, :]) is MatrixExpr
407+
408+
# test 2D @ 2D → 2D
409+
y = m.addMatrixVar((2, 3))
410+
z = m.addMatrixVar((3, 4))
411+
assert type(y @ z) is MatrixExpr

0 commit comments

Comments
 (0)