Skip to content

Commit cb96522

Browse files
authored
Introduced an osqp.SolverStatus enum (#172)
* introduced an osqp.SolverStatus enum * support for SolverError values from osqp; raising OSQPExceptions when we should
1 parent 094acea commit cb96522

File tree

8 files changed

+176
-37
lines changed

8 files changed

+176
-37
lines changed

examples/basic_usage.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,15 @@
1717
# Setup workspace and change alpha parameter
1818
prob.setup(P, q, A, l, u, alpha=1.0)
1919

20+
# Settings can be changed using .update_settings()
21+
prob.update_settings(polishing=1)
22+
2023
# Solve problem
21-
res = prob.solve()
24+
res = prob.solve(raise_error=True)
25+
26+
# Check solver status
27+
# For all values, see https://osqp.org/docs/interfaces/status_values.html
28+
assert res.info.status_val == osqp.SolverStatus.OSQP_SOLVED
2229

2330
print('Status:', res.info.status)
2431
print('Objective value:', res.info.obj_val)

examples/exception_handling.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import osqp
2+
import numpy as np
3+
from scipy import sparse
4+
5+
6+
"""
7+
8+
`osqp.OSQPException`s might be raised during `.setup()`, `.update_settings()`,
9+
or `.solve()`. This example demonstrates how to catch an `osqp.OSQPException`
10+
raised during `.setup()`, and how to compare it to a specific `osqp.SolverError`.
11+
12+
Exceptions other than `osqp.OSQPException` might also be raised, but these
13+
are typically errors in using the wrapper, and are not raised by the underlying
14+
`osqp` library itself.
15+
16+
"""
17+
18+
if __name__ == '__main__':
19+
20+
P = sparse.triu([[2.0, 5.0], [5.0, 1.0]], format='csc')
21+
q = np.array([3.0, 4.0])
22+
A = sparse.csc_matrix([[-1.0, 0.0], [0.0, -1.0], [-1.0, 3.0], [2.0, 5.0], [3.0, 4]])
23+
l = -np.inf * np.ones(A.shape[0])
24+
u = np.array([0.0, 0.0, -15.0, 100.0, 80.0])
25+
26+
prob = osqp.OSQP()
27+
28+
try:
29+
prob.setup(P, q, A, l, u)
30+
except osqp.OSQPException as e:
31+
# Our problem is non-convex, so we get a osqp.OSQPException
32+
# during .setup()
33+
assert e == osqp.SolverError.OSQP_NONCVX_ERROR

src/bindings.cpp.in

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,7 @@ PyOSQPSolver::PyOSQPSolver(
151151

152152
OSQPInt status = osqp_setup(&this->_solver, &this->_P.getcsc(), (OSQPFloat *)this->_q.data(), &this->_A.getcsc(), (OSQPFloat *)this->_l.data(), (OSQPFloat *)this->_u.data(), m, n, settings);
153153
if (status) {
154-
std::string message = "Setup Error (Error Code " + std::to_string(status) + ")";
155-
throw py::value_error(message);
154+
throw py::value_error(std::to_string(status));
156155
}
157156
}
158157

@@ -199,7 +198,12 @@ OSQPInt PyOSQPSolver::solve() {
199198
}
200199

201200
OSQPInt PyOSQPSolver::update_settings(const OSQPSettings& new_settings) {
202-
return osqp_update_settings(this->_solver, &new_settings);
201+
OSQPInt status = osqp_update_settings(this->_solver, &new_settings);
202+
if (status) {
203+
throw py::value_error(std::to_string(status));
204+
} else {
205+
return status;
206+
}
203207
}
204208

205209
OSQPInt PyOSQPSolver::update_rho(OSQPFloat rho_new) {
@@ -353,6 +357,20 @@ PYBIND11_MODULE(@OSQP_EXT_MODULE_NAME@, m) {
353357
.value("OSQP_UNSOLVED", OSQP_UNSOLVED)
354358
.export_values();
355359

360+
// Solver Errors
361+
py::enum_<osqp_error_type>(m, "osqp_error_type", py::module_local())
362+
.value("OSQP_NO_ERROR", OSQP_NO_ERROR)
363+
.value("OSQP_DATA_VALIDATION_ERROR", OSQP_DATA_VALIDATION_ERROR)
364+
.value("OSQP_SETTINGS_VALIDATION_ERROR", OSQP_SETTINGS_VALIDATION_ERROR)
365+
.value("OSQP_LINSYS_SOLVER_INIT_ERROR", OSQP_LINSYS_SOLVER_INIT_ERROR)
366+
.value("OSQP_NONCVX_ERROR", OSQP_NONCVX_ERROR)
367+
.value("OSQP_MEM_ALLOC_ERROR", OSQP_MEM_ALLOC_ERROR)
368+
.value("OSQP_WORKSPACE_NOT_INIT_ERROR", OSQP_WORKSPACE_NOT_INIT_ERROR)
369+
.value("OSQP_ALGEBRA_LOAD_ERROR", OSQP_ALGEBRA_LOAD_ERROR)
370+
.value("OSQP_CODEGEN_DEFINES_ERROR", OSQP_CODEGEN_DEFINES_ERROR)
371+
.value("OSQP_DATA_NOT_INITIALIZED", OSQP_DATA_NOT_INITIALIZED)
372+
.value("OSQP_FUNC_NOT_IMPLEMENTED", OSQP_FUNC_NOT_IMPLEMENTED);
373+
356374
// Preconditioner Type
357375
py::enum_<osqp_precond_type>(m, "osqp_precond_type", py::module_local())
358376
.value("OSQP_NO_PRECONDITIONER", OSQP_NO_PRECONDITIONER)

src/osqp/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
# and is not in version control.
33
from osqp._version import version as __version__ # noqa: F401
44
from osqp.interface import ( # noqa: F401
5+
OSQPException,
56
OSQP,
67
constant,
78
algebra_available,
89
algebras_available,
910
default_algebra,
11+
SolverStatus,
12+
SolverError,
1013
)

src/osqp/interface.py

Lines changed: 90 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import os
33
from types import SimpleNamespace
4+
from enum import IntEnum
45
import shutil
56
import subprocess
67
import warnings
@@ -49,10 +50,31 @@ def default_algebra():
4950
raise RuntimeError('No algebra backend available!')
5051

5152

52-
def constant(which, algebra):
53+
def default_algebra_module():
54+
"""
55+
Get the default algebra module.
56+
Note: importlib.import_module is cached so we pay almost no penalty
57+
for repeated calls to this function.
58+
"""
59+
return importlib.import_module(_ALGEBRA_MODULES[default_algebra()])
60+
61+
62+
def constant(which, algebra='builtin'):
63+
"""
64+
Get a named constant from the extension module.
65+
Since constants are typically consistent across osqp algebras,
66+
we use the `builtin` algebra (always guaranteed to be available)
67+
by default.
68+
"""
5369
m = importlib.import_module(_ALGEBRA_MODULES[algebra])
5470
_constant = getattr(m, which, None)
5571

72+
if which in m.osqp_status_type.__members__:
73+
warnings.warn(
74+
'Direct access to osqp status values will be deprecated. Please use the SolverStatus enum instead.',
75+
PendingDeprecationWarning,
76+
)
77+
5678
# If the constant was exported directly as an atomic type in the extension, use it;
5779
# Otherwise it's an enum out of which we can obtain the raw value
5880
if isinstance(_constant, (int, float, str)):
@@ -67,7 +89,57 @@ def constant(which, algebra):
6789
raise RuntimeError(f'Unknown constant {which}')
6890

6991

92+
def construct_enum(name, binding_enum_name):
93+
"""
94+
Dynamically construct an IntEnum from available enum members.
95+
For all values, see https://osqp.org/docs/interfaces/status_values.html
96+
"""
97+
m = default_algebra_module()
98+
binding_enum = getattr(m, binding_enum_name)
99+
return IntEnum(name, [(v.name, v.value) for v in binding_enum.__members__.values()])
100+
101+
102+
SolverStatus = construct_enum('SolverStatus', 'osqp_status_type')
103+
SolverError = construct_enum('SolverError', 'osqp_error_type')
104+
105+
106+
class OSQPException(Exception):
107+
"""
108+
OSQPException is raised by the wrapper interface when it encounters an
109+
exception by the underlying OSQP solver.
110+
"""
111+
112+
def __init__(self, error_code=None):
113+
if error_code:
114+
self.args = (error_code,)
115+
116+
def __eq__(self, error_code):
117+
return len(self.args) > 0 and self.args[0] == error_code
118+
119+
70120
class OSQP:
121+
122+
"""
123+
For OSQP bindings (see bindings.cpp.in) that throw `ValueError`s
124+
(through `throw py::value_error(...)`), we catch and re-raise them
125+
as `OSQPException`s, with the correct int value as args[0].
126+
"""
127+
128+
@classmethod
129+
def raises_error(cls, fn, *args, **kwargs):
130+
try:
131+
return_value = fn(*args, **kwargs)
132+
except ValueError as e:
133+
if e.args:
134+
error_code = None
135+
try:
136+
error_code = int(e.args[0])
137+
except ValueError:
138+
pass
139+
raise OSQPException(error_code)
140+
else:
141+
return return_value
142+
71143
def __init__(self, *args, **kwargs):
72144
self.m = None
73145
self.n = None
@@ -253,7 +325,7 @@ def update_settings(self, **kwargs):
253325
raise ValueError(f'Unrecognized settings {list(kwargs.keys())}')
254326

255327
if settings_changed and self._solver is not None:
256-
self._solver.update_settings(self.settings)
328+
self.raises_error(self._solver.update_settings, self.settings)
257329

258330
def update(self, **kwargs):
259331
# TODO: sanity-check on types/dimensions
@@ -310,7 +382,8 @@ def setup(self, P, q, A, l, u, **settings):
310382
self.ext.osqp_set_default_settings(self.settings)
311383
self.update_settings(**settings)
312384

313-
self._solver = self.ext.OSQPSolver(
385+
self._solver = self.raises_error(
386+
self.ext.OSQPSolver,
314387
P,
315388
q,
316389
A,
@@ -327,16 +400,23 @@ def warm_start(self, x=None, y=None):
327400
# TODO: sanity checks on types/dimensions
328401
return self._solver.warm_start(x, y)
329402

330-
def solve(self, raise_error=False):
403+
def solve(self, raise_error=None):
404+
if raise_error is None:
405+
warnings.warn(
406+
'The default value of raise_error will change to True in the future.',
407+
PendingDeprecationWarning,
408+
)
409+
raise_error = False
410+
331411
self._solver.solve()
332412

333413
info = self._solver.info
334-
if info.status_val == self.constant('OSQP_NON_CVX'):
414+
if info.status_val == SolverStatus.OSQP_NON_CVX:
335415
info.obj_val = np.nan
336416
# TODO: Handle primal/dual infeasibility
337417

338-
if info.status_val != self.constant('OSQP_SOLVED') and raise_error:
339-
raise ValueError('Problem not solved!')
418+
if info.status_val != SolverStatus.OSQP_SOLVED and raise_error:
419+
raise OSQPException(info.status_val)
340420

341421
# Create a Namespace of OSQPInfo keys and associated values
342422
_info = SimpleNamespace(**{k: getattr(info, k) for k in info.__class__.__dict__ if not k.startswith('__')})
@@ -445,7 +525,7 @@ def adjoint_derivative_compute(self, dx=None, dy=None):
445525
'Problem has not been solved. ' 'You cannot take derivatives. ' 'Please call the solve function.'
446526
)
447527

448-
if results.info.status != 'solved':
528+
if results.info.status_val != SolverStatus.OSQP_SOLVED:
449529
raise ValueError('Problem has not been solved to optimality. ' 'You cannot take derivatives')
450530

451531
if dy is None:
@@ -467,7 +547,7 @@ def adjoint_derivative_get_mat(self, as_dense=True, dP_as_triu=True):
467547
'Problem has not been solved. ' 'You cannot take derivatives. ' 'Please call the solve function.'
468548
)
469549

470-
if results.info.status != 'solved':
550+
if results.info.status_val != SolverStatus.OSQP_SOLVED:
471551
raise ValueError('Problem has not been solved to optimality. ' 'You cannot take derivatives')
472552

473553
P, _ = self._derivative_cache['P'], self._derivative_cache['q']
@@ -501,7 +581,7 @@ def adjoint_derivative_get_vec(self):
501581
'Problem has not been solved. ' 'You cannot take derivatives. ' 'Please call the solve function.'
502582
)
503583

504-
if results.info.status != 'solved':
584+
if results.info.status_val != SolverStatus.OSQP_SOLVED:
505585
raise ValueError('Problem has not been solved to optimality. ' 'You cannot take derivatives')
506586

507587
dq = np.empty(self.n).astype(self._dtype)

src/osqp/nn/torch.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ def _get_update_flag(n_batch: int) -> bool:
122122
"""
123123
num_solvers = len(solvers)
124124
if num_solvers not in (0, n_batch):
125-
raise RuntimeError(
126-
f'Invalid number of solvers: expected 0 or {n_batch},' f' but got {num_solvers}.'
127-
)
125+
raise RuntimeError(f'Invalid number of solvers: expected 0 or {n_batch}, but got {num_solvers}.')
128126
return num_solvers == n_batch
129127

130128
def _inner_solve(i, update_flag, q, l, u, P_val, P_idx, A_val, A_idx, solver_type, eps_abs, eps_rel):
@@ -157,8 +155,8 @@ def _inner_solve(i, update_flag, q, l, u, P_val, P_idx, A_val, A_idx, solver_typ
157155
eps_rel=eps_rel,
158156
)
159157
result = solver.solve()
160-
status = result.info.status
161-
if status != 'solved':
158+
status = result.info.status_val
159+
if status != osqp.SolverStatus.OSQP_SOLVED:
162160
# TODO: We can replace this with something calmer and
163161
# add some more options around potentially ignoring this.
164162
raise RuntimeError(f'Unable to solve QP, status: {status}')

0 commit comments

Comments
 (0)