11from  __future__ import  annotations 
22
33from  copy  import  copy 
4- from  typing  import  Any , Callable , Iterable 
4+ from  typing  import  Any , Callable , Iterable ,  Tuple 
55
66import  cloudpickle 
7- import  numpy  as  np 
87from  sortedcontainers  import  SortedDict , SortedSet 
98
109from  adaptive .learner .base_learner  import  BaseLearner 
10+ from  adaptive .types  import  Int 
1111from  adaptive .utils  import  assign_defaults , partial_function_from_dataframe 
1212
1313try :
1818except  ModuleNotFoundError :
1919    with_pandas  =  False 
2020
21+ try :
22+     from  typing  import  TypeAlias 
23+ except  ImportError :
24+     from  typing_extensions  import  TypeAlias 
25+ 
26+ 
27+ PointType : TypeAlias  =  Tuple [Int , Any ]
28+ 
2129
2230class  _IgnoreFirstArgument :
2331    """Remove the first argument from the call signature. 
@@ -32,9 +40,7 @@ class _IgnoreFirstArgument:
3240    def  __init__ (self , function : Callable ) ->  None :
3341        self .function  =  function   # type: ignore 
3442
35-     def  __call__ (
36-         self , index_point : tuple [int , float  |  np .ndarray ], * args , ** kwargs 
37-     ) ->  float :
43+     def  __call__ (self , index_point : PointType , * args , ** kwargs ):
3844        index , point  =  index_point 
3945        return  self .function (point , * args , ** kwargs )
4046
@@ -85,7 +91,9 @@ def new(self) -> SequenceLearner:
8591        """Return a new `~adaptive.SequenceLearner` without the data.""" 
8692        return  SequenceLearner (self ._original_function , self .sequence )
8793
88-     def  ask (self , n : int , tell_pending : bool  =  True ) ->  tuple [Any , list [float ]]:
94+     def  ask (
95+         self , n : int , tell_pending : bool  =  True 
96+     ) ->  tuple [list [PointType ], list [float ]]:
8997        indices  =  []
9098        points  =  []
9199        loss_improvements  =  []
@@ -105,31 +113,31 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:
105113
106114    def  loss (self , real : bool  =  True ) ->  float :
107115        if  not  (self ._to_do_indices  or  self .pending_points ):
108-             return  0 
116+             return  0.0  
109117        else :
110118            npoints  =  self .npoints  +  (0  if  real  else  len (self .pending_points ))
111119            return  (self ._ntotal  -  npoints ) /  self ._ntotal 
112120
113-     def  remove_unfinished (self ):
121+     def  remove_unfinished (self )  ->   None :
114122        for  i  in  self .pending_points :
115123            self ._to_do_indices .add (i )
116124        self .pending_points  =  set ()
117125
118-     def  tell (self , point : tuple [ int ,  Any ] , value : Any ) ->  None :
126+     def  tell (self , point : PointType , value : Any ) ->  None :
119127        index , point  =  point 
120128        self .data [index ] =  value 
121129        self .pending_points .discard (index )
122130        self ._to_do_indices .discard (index )
123131
124-     def  tell_pending (self , point : Any ) ->  None :
132+     def  tell_pending (self , point : PointType ) ->  None :
125133        index , point  =  point 
126134        self .pending_points .add (index )
127135        self ._to_do_indices .discard (index )
128136
129-     def  done (self ):
137+     def  done (self )  ->   bool :
130138        return  not  self ._to_do_indices  and  not  self .pending_points 
131139
132-     def  result (self ):
140+     def  result (self )  ->   list [ Any ] :
133141        """Get the function values in the same order as ``sequence``.""" 
134142        if  not  self .done ():
135143            raise  Exception ("Learner is not yet complete." )
@@ -217,16 +225,18 @@ def load_dataframe(
217225        y_name : str, optional 
218226            The ``y_name`` used in ``to_dataframe``, by default "y" 
219227        """ 
220-         self .tell_many (df [[index_name , x_name ]].values , df [y_name ].values )
228+         indices  =  df [index_name ].values 
229+         xs  =  df [x_name ].values 
230+         self .tell_many (zip (indices , xs ), df [y_name ].values )
221231        if  with_default_function_args :
222232            self .function  =  partial_function_from_dataframe (
223233                self ._original_function , df , function_prefix 
224234            )
225235
226-     def  _get_data (self ) ->  SortedDict :
236+     def  _get_data (self ) ->  dict [ int ,  Any ] :
227237        return  self .data 
228238
229-     def  _set_data (self , data : SortedDict ) ->  None :
239+     def  _set_data (self , data : dict [ int ,  Any ] ) ->  None :
230240        if  data :
231241            indices , values  =  zip (* data .items ())
232242            # the points aren't used by tell, so we can safely pass None 
0 commit comments