1- from typing import Callable , Sequence , Union , Tuple , List , Optional
1+ from contextlib import contextmanager
2+ from typing import Callable , Sequence , Union , Tuple , List , Optional , Iterator
23import os
34import time
45
56import numpy as np
67import matplotlib
78import matplotlib .pyplot as plt
89
9- from qcodes .dataset .measurements import Measurement
10+ from qcodes .dataset .measurements import Measurement , res_type , DataSaver
1011from qcodes .instrument .base import _BaseParameter
1112from qcodes .dataset .plotting import plot_by_id
1213from qcodes import config
1314
15+ ActionsT = Sequence [Callable [[], None ]]
16+
17+ ParamMeasT = Union [_BaseParameter , Callable [[], None ]]
18+
1419AxesTuple = Tuple [matplotlib .axes .Axes , matplotlib .colorbar .Colorbar ]
1520AxesTupleList = Tuple [List [matplotlib .axes .Axes ],
1621 List [Optional [matplotlib .colorbar .Colorbar ]]]
1722AxesTupleListWithRunId = Tuple [int , List [matplotlib .axes .Axes ],
1823 List [Optional [matplotlib .colorbar .Colorbar ]]]
19- number = Union [float , int ]
2024
2125
22- def do0d (* param_meas : Union [_BaseParameter , Callable [[], None ]],
23- write_period : Optional [float ] = None ,
24- do_plot : bool = True ) -> AxesTupleListWithRunId :
26+ def _process_params_meas (param_meas : ParamMeasT ) -> List [res_type ]:
27+ output = []
28+ for parameter in param_meas :
29+ if isinstance (parameter , _BaseParameter ):
30+ output .append ((parameter , parameter .get ()))
31+ elif callable (parameter ):
32+ parameter ()
33+ return output
34+
35+
36+ def _register_parameters (
37+ meas : Measurement ,
38+ param_meas : List [ParamMeasT ],
39+ setpoints : Optional [List [_BaseParameter ]] = None
40+ ) -> None :
41+ for parameter in param_meas :
42+ if isinstance (parameter , _BaseParameter ):
43+ meas .register_parameter (parameter ,
44+ setpoints = setpoints )
45+
46+
47+ def _register_actions (
48+ meas : Measurement ,
49+ enter_actions : ActionsT ,
50+ exit_actions : ActionsT
51+ ) -> None :
52+ for action in enter_actions :
53+ # this omits the possibility of passing
54+ # argument to enter and exit actions.
55+ # Do we want that?
56+ meas .add_before_run (action , ())
57+ for action in exit_actions :
58+ meas .add_after_run (action , ())
59+
60+
61+
62+ def _set_write_period (
63+ meas : Measurement ,
64+ write_period : Optional [float ] = None
65+ ) -> None :
66+ if write_period is not None :
67+ meas .write_period = write_period
68+
69+
70+ @contextmanager
71+ def _catch_keyboard_interrupts () -> Iterator [Callable [[], bool ]]:
72+ interrupted = False
73+ def has_been_interrupted ():
74+ nonlocal interrupted
75+ return interrupted
76+ try :
77+ yield has_been_interrupted
78+ except KeyboardInterrupt :
79+ interrupted = True
80+
81+
82+
83+ def do0d (
84+ * param_meas : ParamMeasT ,
85+ write_period : Optional [float ] = None ,
86+ do_plot : bool = True
87+ ) -> AxesTupleListWithRunId :
2588 """
2689 Perform a measurement of a single parameter. This is probably most
2790 useful for an ArrayParamter that already returns an array of data points
@@ -38,41 +101,27 @@ def do0d(*param_meas: Union[_BaseParameter, Callable[[], None]],
38101 The run_id of the DataSet created
39102 """
40103 meas = Measurement ()
41- if write_period is not None :
42- meas .write_period = write_period
43- output = []
44-
45- for parameter in param_meas :
46- meas .register_parameter (parameter )
47- output .append ([parameter , None ])
104+ _register_parameters (meas , param_meas )
105+ _set_write_period (meas , write_period )
48106
49107 with meas .run () as datasaver :
108+ datasaver .add_result (* _process_params_meas (param_meas ))
109+
110+ return _handle_plotting (datasaver , do_plot )
50111
51- for i , parameter in enumerate (param_meas ):
52- if isinstance (parameter , _BaseParameter ):
53- output [i ][1 ] = parameter .get ()
54- elif callable (parameter ):
55- parameter ()
56- datasaver .add_result (* output )
57- dataid = datasaver .run_id
58112
59- if do_plot is True :
60- ax , cbs = _save_image (datasaver )
61- else :
62- ax = None ,
63- cbs = None
64113
65- return dataid , ax , cbs
66114
67115
68- def do1d (param_set : _BaseParameter , start : number , stop : number ,
69- num_points : int , delay : number ,
70- * param_meas : Union [_BaseParameter , Callable [[], None ]],
71- enter_actions : Sequence [Callable [[], None ]] = (),
72- exit_actions : Sequence [Callable [[], None ]] = (),
73- write_period : Optional [float ] = None ,
74- do_plot : bool = True ) \
75- -> AxesTupleListWithRunId :
116+ def do1d (
117+ param_set : _BaseParameter , start : float , stop : float ,
118+ num_points : int , delay : float ,
119+ * param_meas : ParamMeasT ,
120+ enter_actions : ActionsT = (),
121+ exit_actions : ActionsT = (),
122+ write_period : Optional [float ] = None ,
123+ do_plot : bool = True
124+ ) -> AxesTupleListWithRunId :
76125 """
77126 Perform a 1D scan of ``param_set`` from ``start`` to ``stop`` in
78127 ``num_points`` measuring param_meas at each step. In case param_meas is
@@ -99,72 +148,38 @@ def do1d(param_set: _BaseParameter, start: number, stop: number,
99148 The run_id of the DataSet created
100149 """
101150 meas = Measurement ()
102- if write_period is not None :
103- meas .write_period = write_period
104- meas .register_parameter (
105- param_set ) # register the first independent parameter
106- output = []
151+ _register_parameters (meas , (param_set ,))
152+ _register_parameters (meas , param_meas , setpoints = (param_set ,))
153+ _set_write_period (meas , write_period )
154+ _register_actions (meas , enter_actions , exit_actions )
107155 param_set .post_delay = delay
108- interrupted = False
109-
110- for action in enter_actions :
111- # this omits the posibility of passing
112- # argument to enter and exit actions.
113- # Do we want that?
114- meas .add_before_run (action , ())
115- for action in exit_actions :
116- meas .add_after_run (action , ())
117156
118157 # do1D enforces a simple relationship between measured parameters
119158 # and set parameters. For anything more complicated this should be
120159 # reimplemented from scratch
121- for parameter in param_meas :
122- if isinstance (parameter , _BaseParameter ):
123- meas .register_parameter (parameter , setpoints = (param_set ,))
124- output .append ([parameter , None ])
125-
126- try :
127- with meas .run () as datasaver :
128-
129- for set_point in np .linspace (start , stop , num_points ):
130- param_set .set (set_point )
131- output = []
132- for parameter in param_meas :
133- if isinstance (parameter , _BaseParameter ):
134- output .append ((parameter , parameter .get ()))
135- elif callable (parameter ):
136- parameter ()
137- datasaver .add_result ((param_set , set_point ),
138- * output )
139- except KeyboardInterrupt :
140- interrupted = True
141-
142- dataid = datasaver .run_id # convenient to have for plotting
143-
144- if do_plot is True :
145- ax , cbs = _save_image (datasaver )
146- else :
147- ax = None ,
148- cbs = None
149-
150- if interrupted :
151- raise KeyboardInterrupt
152- return dataid , ax , cbs
153-
154-
155- def do2d (param_set1 : _BaseParameter , start1 : number , stop1 : number ,
156- num_points1 : int , delay1 : number ,
157- param_set2 : _BaseParameter , start2 : number , stop2 : number ,
158- num_points2 : int , delay2 : number ,
159- * param_meas : Union [_BaseParameter , Callable [[], None ]],
160- set_before_sweep : Optional [bool ] = False ,
161- enter_actions : Sequence [Callable [[], None ]] = (),
162- exit_actions : Sequence [Callable [[], None ]] = (),
163- before_inner_actions : Sequence [Callable [[], None ]] = (),
164- after_inner_actions : Sequence [Callable [[], None ]] = (),
165- write_period : Optional [float ] = None ,
166- flush_columns : bool = False ,
167- do_plot : bool = True ) -> AxesTupleListWithRunId :
160+ with _catch_keyboard_interrupts () as interrupted , meas .run () as datasaver :
161+ for set_point in np .linspace (start , stop , num_points ):
162+ param_set .set (set_point )
163+ datasaver .add_result ((param_set , set_point ),
164+ * _process_params_meas (param_meas ))
165+ return _handle_plotting (datasaver , do_plot , interrupted ())
166+
167+
168+ def do2d (
169+ param_set1 : _BaseParameter , start1 : float , stop1 : float ,
170+ num_points1 : int , delay1 : float ,
171+ param_set2 : _BaseParameter , start2 : float , stop2 : float ,
172+ num_points2 : int , delay2 : float ,
173+ * param_meas : ParamMeasT ,
174+ set_before_sweep : Optional [bool ] = False ,
175+ enter_actions : ActionsT = (),
176+ exit_actions : ActionsT = (),
177+ before_inner_actions : ActionsT = (),
178+ after_inner_actions : ActionsT = (),
179+ write_period : Optional [float ] = None ,
180+ flush_columns : bool = False ,
181+ do_plot : bool = True
182+ ) -> AxesTupleListWithRunId :
168183
169184 """
170185 Perform a 1D scan of ``param_set1`` from ``start1`` to ``stop1`` in
@@ -202,29 +217,16 @@ def do2d(param_set1: _BaseParameter, start1: number, stop1: number,
202217 """
203218
204219 meas = Measurement ()
205- if write_period is not None :
206- meas .write_period = write_period
207- meas .register_parameter (param_set1 )
220+ _register_parameters (meas , (param_set1 , param_set2 ))
221+ _register_parameters (meas , param_meas , setpoints = (param_set1 , param_set2 ))
222+ _set_write_period (meas , write_period )
223+ _register_actions (meas , enter_actions , exit_actions )
224+
208225 param_set1 .post_delay = delay1
209- meas .register_parameter (param_set2 )
210226 param_set2 .post_delay = delay2
211- interrupted = False
212- for action in enter_actions :
213- # this omits the possibility of passing
214- # argument to enter and exit actions.
215- # Do we want that?
216- meas .add_before_run (action , ())
217-
218- for action in exit_actions :
219- meas .add_after_run (action , ())
220227
221- for parameter in param_meas :
222- if isinstance (parameter , _BaseParameter ):
223- meas .register_parameter (parameter ,
224- setpoints = (param_set1 , param_set2 ))
225- try :
226- with meas .run () as datasaver :
227- for set_point1 in np .linspace (start1 , stop1 , num_points1 ):
228+ with _catch_keyboard_interrupts () as interrupted , meas .run () as datasaver :
229+ for set_point1 in np .linspace (start1 , stop1 , num_points1 ):
228230 if set_before_sweep :
229231 param_set2 .set (start2 )
230232
@@ -237,67 +239,64 @@ def do2d(param_set1: _BaseParameter, start1: number, stop1: number,
237239 pass
238240 else :
239241 param_set2 .set (set_point2 )
240- output = []
241- for parameter in param_meas :
242- if isinstance (parameter , _BaseParameter ):
243- output .append ((parameter , parameter .get ()))
244- elif callable (parameter ):
245- parameter ()
242+
246243 datasaver .add_result ((param_set1 , set_point1 ),
247244 (param_set2 , set_point2 ),
248- * output )
245+ * _process_params_meas ( param_meas ) )
249246 for action in after_inner_actions :
250247 action ()
251248 if flush_columns :
252249 datasaver .flush_data_to_database ()
253- except KeyboardInterrupt :
254- interrupted = True
255250
256- dataid = datasaver . run_id
251+ return _handle_plotting ( datasaver , do_plot , interrupted ())
257252
258- if do_plot is True :
259- ax , cbs = _save_image (datasaver )
260- else :
261- ax = None ,
262- cbs = None
263- if interrupted :
264- raise KeyboardInterrupt
265253
266- return dataid , ax , cbs
267254
268255
269- def _save_image (datasaver ) -> AxesTupleList :
256+ def _handle_plotting (
257+ datasaver : DataSaver ,
258+ do_plot : bool = True ,
259+ interrupted : bool = False
260+ ) -> AxesTupleList :
270261 """
271262 Save the plots created by datasaver as pdf and png
272263
273264 Args:
274265 datasaver: a measurement datasaver that contains a dataset to be saved
275266 as plot.
267+ :param do_plot:
276268
277269 """
278- plt .ioff ()
279270 dataid = datasaver .run_id
271+ if do_plot == True :
272+ res = _create_plots (datasaver )
273+ else :
274+ res = dataid , None , None
275+
276+ if interrupted :
277+ raise KeyboardInterrupt
278+
279+ return res
280+
281+
282+ def _create_plots (datasaver : DataSaver ) -> AxesTupleList :
283+ dataid = datasaver .run_id
284+ plt .ioff ()
280285 start = time .time ()
281286 axes , cbs = plot_by_id (dataid )
282287 stop = time .time ()
283- print (f"plot by id took { stop - start } " )
284-
288+ print (f"plot by id took { stop - start } " )
285289 mainfolder = config .user .mainfolder
286290 experiment_name = datasaver ._dataset .exp_name
287291 sample_name = datasaver ._dataset .sample_name
288-
289292 storage_dir = os .path .join (mainfolder , experiment_name , sample_name )
290293 os .makedirs (storage_dir , exist_ok = True )
291-
292294 png_dir = os .path .join (storage_dir , 'png' )
293295 pdf_dif = os .path .join (storage_dir , 'pdf' )
294-
295296 os .makedirs (png_dir , exist_ok = True )
296297 os .makedirs (pdf_dif , exist_ok = True )
297-
298298 save_pdf = True
299299 save_png = True
300-
301300 for i , ax in enumerate (axes ):
302301 if save_pdf :
303302 full_path = os .path .join (pdf_dif , f'{ dataid } _{ i } .pdf' )
@@ -306,4 +305,5 @@ def _save_image(datasaver) -> AxesTupleList:
306305 full_path = os .path .join (png_dir , f'{ dataid } _{ i } .png' )
307306 ax .figure .savefig (full_path , dpi = 500 )
308307 plt .ion ()
309- return axes , cbs
308+ res = dataid , axes , cbs
309+ return res
0 commit comments