11import sys
22import os
33from types import SimpleNamespace
4+ from enum import IntEnum
45import shutil
56import subprocess
67import warnings
@@ -49,10 +50,31 @@ def default_algebra():
4950 raise RuntimeError ('No algebra backend available!' )
5051
5152
52- def constant (which , algebra ):
53+ def default_algebra_module ():
54+ """
55+ Get the default algebra module.
56+ Note: importlib.import_module is cached so we pay almost no penalty
57+ for repeated calls to this function.
58+ """
59+ return importlib .import_module (_ALGEBRA_MODULES [default_algebra ()])
60+
61+
62+ def constant (which , algebra = 'builtin' ):
63+ """
64+ Get a named constant from the extension module.
65+ Since constants are typically consistent across osqp algebras,
66+ we use the `builtin` algebra (always guaranteed to be available)
67+ by default.
68+ """
5369 m = importlib .import_module (_ALGEBRA_MODULES [algebra ])
5470 _constant = getattr (m , which , None )
5571
72+ if which in m .osqp_status_type .__members__ :
73+ warnings .warn (
74+ 'Direct access to osqp status values will be deprecated. Please use the SolverStatus enum instead.' ,
75+ PendingDeprecationWarning ,
76+ )
77+
5678 # If the constant was exported directly as an atomic type in the extension, use it;
5779 # Otherwise it's an enum out of which we can obtain the raw value
5880 if isinstance (_constant , (int , float , str )):
@@ -67,7 +89,57 @@ def constant(which, algebra):
6789 raise RuntimeError (f'Unknown constant { which } ' )
6890
6991
92+ def construct_enum (name , binding_enum_name ):
93+ """
94+ Dynamically construct an IntEnum from available enum members.
95+ For all values, see https://osqp.org/docs/interfaces/status_values.html
96+ """
97+ m = default_algebra_module ()
98+ binding_enum = getattr (m , binding_enum_name )
99+ return IntEnum (name , [(v .name , v .value ) for v in binding_enum .__members__ .values ()])
100+
101+
102+ SolverStatus = construct_enum ('SolverStatus' , 'osqp_status_type' )
103+ SolverError = construct_enum ('SolverError' , 'osqp_error_type' )
104+
105+
106+ class OSQPException (Exception ):
107+ """
108+ OSQPException is raised by the wrapper interface when it encounters an
109+ exception by the underlying OSQP solver.
110+ """
111+
112+ def __init__ (self , error_code = None ):
113+ if error_code :
114+ self .args = (error_code ,)
115+
116+ def __eq__ (self , error_code ):
117+ return len (self .args ) > 0 and self .args [0 ] == error_code
118+
119+
70120class OSQP :
121+
122+ """
123+ For OSQP bindings (see bindings.cpp.in) that throw `ValueError`s
124+ (through `throw py::value_error(...)`), we catch and re-raise them
125+ as `OSQPException`s, with the correct int value as args[0].
126+ """
127+
128+ @classmethod
129+ def raises_error (cls , fn , * args , ** kwargs ):
130+ try :
131+ return_value = fn (* args , ** kwargs )
132+ except ValueError as e :
133+ if e .args :
134+ error_code = None
135+ try :
136+ error_code = int (e .args [0 ])
137+ except ValueError :
138+ pass
139+ raise OSQPException (error_code )
140+ else :
141+ return return_value
142+
71143 def __init__ (self , * args , ** kwargs ):
72144 self .m = None
73145 self .n = None
@@ -253,7 +325,7 @@ def update_settings(self, **kwargs):
253325 raise ValueError (f'Unrecognized settings { list (kwargs .keys ())} ' )
254326
255327 if settings_changed and self ._solver is not None :
256- self ._solver .update_settings ( self .settings )
328+ self .raises_error ( self . _solver .update_settings , self .settings )
257329
258330 def update (self , ** kwargs ):
259331 # TODO: sanity-check on types/dimensions
@@ -310,7 +382,8 @@ def setup(self, P, q, A, l, u, **settings):
310382 self .ext .osqp_set_default_settings (self .settings )
311383 self .update_settings (** settings )
312384
313- self ._solver = self .ext .OSQPSolver (
385+ self ._solver = self .raises_error (
386+ self .ext .OSQPSolver ,
314387 P ,
315388 q ,
316389 A ,
@@ -327,16 +400,23 @@ def warm_start(self, x=None, y=None):
327400 # TODO: sanity checks on types/dimensions
328401 return self ._solver .warm_start (x , y )
329402
330- def solve (self , raise_error = False ):
403+ def solve (self , raise_error = None ):
404+ if raise_error is None :
405+ warnings .warn (
406+ 'The default value of raise_error will change to True in the future.' ,
407+ PendingDeprecationWarning ,
408+ )
409+ raise_error = False
410+
331411 self ._solver .solve ()
332412
333413 info = self ._solver .info
334- if info .status_val == self . constant ( ' OSQP_NON_CVX' ) :
414+ if info .status_val == SolverStatus . OSQP_NON_CVX :
335415 info .obj_val = np .nan
336416 # TODO: Handle primal/dual infeasibility
337417
338- if info .status_val != self . constant ( ' OSQP_SOLVED' ) and raise_error :
339- raise ValueError ( 'Problem not solved!' )
418+ if info .status_val != SolverStatus . OSQP_SOLVED and raise_error :
419+ raise OSQPException ( info . status_val )
340420
341421 # Create a Namespace of OSQPInfo keys and associated values
342422 _info = SimpleNamespace (** {k : getattr (info , k ) for k in info .__class__ .__dict__ if not k .startswith ('__' )})
@@ -445,7 +525,7 @@ def adjoint_derivative_compute(self, dx=None, dy=None):
445525 'Problem has not been solved. ' 'You cannot take derivatives. ' 'Please call the solve function.'
446526 )
447527
448- if results .info .status != 'solved' :
528+ if results .info .status_val != SolverStatus . OSQP_SOLVED :
449529 raise ValueError ('Problem has not been solved to optimality. ' 'You cannot take derivatives' )
450530
451531 if dy is None :
@@ -467,7 +547,7 @@ def adjoint_derivative_get_mat(self, as_dense=True, dP_as_triu=True):
467547 'Problem has not been solved. ' 'You cannot take derivatives. ' 'Please call the solve function.'
468548 )
469549
470- if results .info .status != 'solved' :
550+ if results .info .status_val != SolverStatus . OSQP_SOLVED :
471551 raise ValueError ('Problem has not been solved to optimality. ' 'You cannot take derivatives' )
472552
473553 P , _ = self ._derivative_cache ['P' ], self ._derivative_cache ['q' ]
@@ -501,7 +581,7 @@ def adjoint_derivative_get_vec(self):
501581 'Problem has not been solved. ' 'You cannot take derivatives. ' 'Please call the solve function.'
502582 )
503583
504- if results .info .status != 'solved' :
584+ if results .info .status_val != SolverStatus . OSQP_SOLVED :
505585 raise ValueError ('Problem has not been solved to optimality. ' 'You cannot take derivatives' )
506586
507587 dq = np .empty (self .n ).astype (self ._dtype )
0 commit comments