Skip to content

Commit 4fe8388

Browse files
authored
Merge pull request #193 from Dominik-Vogel/refactor_doNds
Refactor doNds
2 parents 02684a9 + 96a0aff commit 4fe8388

File tree

1 file changed

+143
-143
lines changed

1 file changed

+143
-143
lines changed

qdev_wrappers/dataset/doNd.py

Lines changed: 143 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,90 @@
1-
from typing import Callable, Sequence, Union, Tuple, List, Optional
1+
from contextlib import contextmanager
2+
from typing import Callable, Sequence, Union, Tuple, List, Optional, Iterator
23
import os
34
import time
45

56
import numpy as np
67
import matplotlib
78
import matplotlib.pyplot as plt
89

9-
from qcodes.dataset.measurements import Measurement
10+
from qcodes.dataset.measurements import Measurement, res_type, DataSaver
1011
from qcodes.instrument.base import _BaseParameter
1112
from qcodes.dataset.plotting import plot_by_id
1213
from qcodes import config
1314

15+
ActionsT = Sequence[Callable[[], None]]
16+
17+
ParamMeasT = Union[_BaseParameter, Callable[[], None]]
18+
1419
AxesTuple = Tuple[matplotlib.axes.Axes, matplotlib.colorbar.Colorbar]
1520
AxesTupleList = Tuple[List[matplotlib.axes.Axes],
1621
List[Optional[matplotlib.colorbar.Colorbar]]]
1722
AxesTupleListWithRunId = Tuple[int, List[matplotlib.axes.Axes],
1823
List[Optional[matplotlib.colorbar.Colorbar]]]
19-
number = Union[float, int]
2024

2125

22-
def do0d(*param_meas: Union[_BaseParameter, Callable[[], None]],
23-
write_period: Optional[float] = None,
24-
do_plot: bool = True) -> AxesTupleListWithRunId:
26+
def _process_params_meas(param_meas: ParamMeasT) -> List[res_type]:
27+
output = []
28+
for parameter in param_meas:
29+
if isinstance(parameter, _BaseParameter):
30+
output.append((parameter, parameter.get()))
31+
elif callable(parameter):
32+
parameter()
33+
return output
34+
35+
36+
def _register_parameters(
37+
meas: Measurement,
38+
param_meas: List[ParamMeasT],
39+
setpoints: Optional[List[_BaseParameter]] = None
40+
) -> None:
41+
for parameter in param_meas:
42+
if isinstance(parameter, _BaseParameter):
43+
meas.register_parameter(parameter,
44+
setpoints=setpoints)
45+
46+
47+
def _register_actions(
48+
meas: Measurement,
49+
enter_actions: ActionsT,
50+
exit_actions: ActionsT
51+
) -> None:
52+
for action in enter_actions:
53+
# this omits the possibility of passing
54+
# argument to enter and exit actions.
55+
# Do we want that?
56+
meas.add_before_run(action, ())
57+
for action in exit_actions:
58+
meas.add_after_run(action, ())
59+
60+
61+
62+
def _set_write_period(
63+
meas: Measurement,
64+
write_period: Optional[float] = None
65+
) -> None:
66+
if write_period is not None:
67+
meas.write_period = write_period
68+
69+
70+
@contextmanager
71+
def _catch_keyboard_interrupts() -> Iterator[Callable[[], bool]]:
72+
interrupted = False
73+
def has_been_interrupted():
74+
nonlocal interrupted
75+
return interrupted
76+
try:
77+
yield has_been_interrupted
78+
except KeyboardInterrupt:
79+
interrupted = True
80+
81+
82+
83+
def do0d(
84+
*param_meas: ParamMeasT,
85+
write_period: Optional[float] = None,
86+
do_plot: bool = True
87+
) -> AxesTupleListWithRunId:
2588
"""
2689
Perform a measurement of a single parameter. This is probably most
2790
useful for an ArrayParamter that already returns an array of data points
@@ -38,41 +101,27 @@ def do0d(*param_meas: Union[_BaseParameter, Callable[[], None]],
38101
The run_id of the DataSet created
39102
"""
40103
meas = Measurement()
41-
if write_period is not None:
42-
meas.write_period = write_period
43-
output = []
44-
45-
for parameter in param_meas:
46-
meas.register_parameter(parameter)
47-
output.append([parameter, None])
104+
_register_parameters(meas, param_meas)
105+
_set_write_period(meas, write_period)
48106

49107
with meas.run() as datasaver:
108+
datasaver.add_result(*_process_params_meas(param_meas))
109+
110+
return _handle_plotting(datasaver, do_plot)
50111

51-
for i, parameter in enumerate(param_meas):
52-
if isinstance(parameter, _BaseParameter):
53-
output[i][1] = parameter.get()
54-
elif callable(parameter):
55-
parameter()
56-
datasaver.add_result(*output)
57-
dataid = datasaver.run_id
58112

59-
if do_plot is True:
60-
ax, cbs = _save_image(datasaver)
61-
else:
62-
ax = None,
63-
cbs = None
64113

65-
return dataid, ax, cbs
66114

67115

68-
def do1d(param_set: _BaseParameter, start: number, stop: number,
69-
num_points: int, delay: number,
70-
*param_meas: Union[_BaseParameter, Callable[[], None]],
71-
enter_actions: Sequence[Callable[[], None]] = (),
72-
exit_actions: Sequence[Callable[[], None]] = (),
73-
write_period: Optional[float] = None,
74-
do_plot: bool = True) \
75-
-> AxesTupleListWithRunId:
116+
def do1d(
117+
param_set: _BaseParameter, start: float, stop: float,
118+
num_points: int, delay: float,
119+
*param_meas: ParamMeasT,
120+
enter_actions: ActionsT = (),
121+
exit_actions: ActionsT = (),
122+
write_period: Optional[float] = None,
123+
do_plot: bool = True
124+
) -> AxesTupleListWithRunId:
76125
"""
77126
Perform a 1D scan of ``param_set`` from ``start`` to ``stop`` in
78127
``num_points`` measuring param_meas at each step. In case param_meas is
@@ -99,72 +148,38 @@ def do1d(param_set: _BaseParameter, start: number, stop: number,
99148
The run_id of the DataSet created
100149
"""
101150
meas = Measurement()
102-
if write_period is not None:
103-
meas.write_period = write_period
104-
meas.register_parameter(
105-
param_set) # register the first independent parameter
106-
output = []
151+
_register_parameters(meas, (param_set,))
152+
_register_parameters(meas, param_meas, setpoints=(param_set,))
153+
_set_write_period(meas, write_period)
154+
_register_actions(meas, enter_actions, exit_actions)
107155
param_set.post_delay = delay
108-
interrupted = False
109-
110-
for action in enter_actions:
111-
# this omits the posibility of passing
112-
# argument to enter and exit actions.
113-
# Do we want that?
114-
meas.add_before_run(action, ())
115-
for action in exit_actions:
116-
meas.add_after_run(action, ())
117156

118157
# do1D enforces a simple relationship between measured parameters
119158
# and set parameters. For anything more complicated this should be
120159
# reimplemented from scratch
121-
for parameter in param_meas:
122-
if isinstance(parameter, _BaseParameter):
123-
meas.register_parameter(parameter, setpoints=(param_set,))
124-
output.append([parameter, None])
125-
126-
try:
127-
with meas.run() as datasaver:
128-
129-
for set_point in np.linspace(start, stop, num_points):
130-
param_set.set(set_point)
131-
output = []
132-
for parameter in param_meas:
133-
if isinstance(parameter, _BaseParameter):
134-
output.append((parameter, parameter.get()))
135-
elif callable(parameter):
136-
parameter()
137-
datasaver.add_result((param_set, set_point),
138-
*output)
139-
except KeyboardInterrupt:
140-
interrupted = True
141-
142-
dataid = datasaver.run_id # convenient to have for plotting
143-
144-
if do_plot is True:
145-
ax, cbs = _save_image(datasaver)
146-
else:
147-
ax = None,
148-
cbs = None
149-
150-
if interrupted:
151-
raise KeyboardInterrupt
152-
return dataid, ax, cbs
153-
154-
155-
def do2d(param_set1: _BaseParameter, start1: number, stop1: number,
156-
num_points1: int, delay1: number,
157-
param_set2: _BaseParameter, start2: number, stop2: number,
158-
num_points2: int, delay2: number,
159-
*param_meas: Union[_BaseParameter, Callable[[], None]],
160-
set_before_sweep: Optional[bool] = False,
161-
enter_actions: Sequence[Callable[[], None]] = (),
162-
exit_actions: Sequence[Callable[[], None]] = (),
163-
before_inner_actions: Sequence[Callable[[], None]] = (),
164-
after_inner_actions: Sequence[Callable[[], None]] = (),
165-
write_period: Optional[float] = None,
166-
flush_columns: bool = False,
167-
do_plot: bool=True) -> AxesTupleListWithRunId:
160+
with _catch_keyboard_interrupts() as interrupted, meas.run() as datasaver:
161+
for set_point in np.linspace(start, stop, num_points):
162+
param_set.set(set_point)
163+
datasaver.add_result((param_set, set_point),
164+
*_process_params_meas(param_meas))
165+
return _handle_plotting(datasaver, do_plot, interrupted())
166+
167+
168+
def do2d(
169+
param_set1: _BaseParameter, start1: float, stop1: float,
170+
num_points1: int, delay1: float,
171+
param_set2: _BaseParameter, start2: float, stop2: float,
172+
num_points2: int, delay2: float,
173+
*param_meas: ParamMeasT,
174+
set_before_sweep: Optional[bool] = False,
175+
enter_actions: ActionsT = (),
176+
exit_actions: ActionsT = (),
177+
before_inner_actions: ActionsT = (),
178+
after_inner_actions: ActionsT = (),
179+
write_period: Optional[float] = None,
180+
flush_columns: bool = False,
181+
do_plot: bool=True
182+
) -> AxesTupleListWithRunId:
168183

169184
"""
170185
Perform a 1D scan of ``param_set1`` from ``start1`` to ``stop1`` in
@@ -202,29 +217,16 @@ def do2d(param_set1: _BaseParameter, start1: number, stop1: number,
202217
"""
203218

204219
meas = Measurement()
205-
if write_period is not None:
206-
meas.write_period = write_period
207-
meas.register_parameter(param_set1)
220+
_register_parameters(meas, (param_set1, param_set2))
221+
_register_parameters(meas, param_meas, setpoints=(param_set1, param_set2))
222+
_set_write_period(meas, write_period)
223+
_register_actions(meas, enter_actions, exit_actions)
224+
208225
param_set1.post_delay = delay1
209-
meas.register_parameter(param_set2)
210226
param_set2.post_delay = delay2
211-
interrupted = False
212-
for action in enter_actions:
213-
# this omits the possibility of passing
214-
# argument to enter and exit actions.
215-
# Do we want that?
216-
meas.add_before_run(action, ())
217-
218-
for action in exit_actions:
219-
meas.add_after_run(action, ())
220227

221-
for parameter in param_meas:
222-
if isinstance(parameter, _BaseParameter):
223-
meas.register_parameter(parameter,
224-
setpoints=(param_set1, param_set2))
225-
try:
226-
with meas.run() as datasaver:
227-
for set_point1 in np.linspace(start1, stop1, num_points1):
228+
with _catch_keyboard_interrupts() as interrupted, meas.run() as datasaver:
229+
for set_point1 in np.linspace(start1, stop1, num_points1):
228230
if set_before_sweep:
229231
param_set2.set(start2)
230232

@@ -237,67 +239,64 @@ def do2d(param_set1: _BaseParameter, start1: number, stop1: number,
237239
pass
238240
else:
239241
param_set2.set(set_point2)
240-
output = []
241-
for parameter in param_meas:
242-
if isinstance(parameter, _BaseParameter):
243-
output.append((parameter, parameter.get()))
244-
elif callable(parameter):
245-
parameter()
242+
246243
datasaver.add_result((param_set1, set_point1),
247244
(param_set2, set_point2),
248-
*output)
245+
*_process_params_meas(param_meas))
249246
for action in after_inner_actions:
250247
action()
251248
if flush_columns:
252249
datasaver.flush_data_to_database()
253-
except KeyboardInterrupt:
254-
interrupted = True
255250

256-
dataid = datasaver.run_id
251+
return _handle_plotting(datasaver, do_plot, interrupted())
257252

258-
if do_plot is True:
259-
ax, cbs = _save_image(datasaver)
260-
else:
261-
ax = None,
262-
cbs = None
263-
if interrupted:
264-
raise KeyboardInterrupt
265253

266-
return dataid, ax, cbs
267254

268255

269-
def _save_image(datasaver) -> AxesTupleList:
256+
def _handle_plotting(
257+
datasaver: DataSaver,
258+
do_plot: bool = True,
259+
interrupted: bool = False
260+
) -> AxesTupleList:
270261
"""
271262
Save the plots created by datasaver as pdf and png
272263
273264
Args:
274265
datasaver: a measurement datasaver that contains a dataset to be saved
275266
as plot.
267+
:param do_plot:
276268
277269
"""
278-
plt.ioff()
279270
dataid = datasaver.run_id
271+
if do_plot == True:
272+
res = _create_plots(datasaver)
273+
else:
274+
res = dataid, None, None
275+
276+
if interrupted:
277+
raise KeyboardInterrupt
278+
279+
return res
280+
281+
282+
def _create_plots(datasaver: DataSaver) -> AxesTupleList:
283+
dataid = datasaver.run_id
284+
plt.ioff()
280285
start = time.time()
281286
axes, cbs = plot_by_id(dataid)
282287
stop = time.time()
283-
print(f"plot by id took {stop-start}")
284-
288+
print(f"plot by id took {stop - start}")
285289
mainfolder = config.user.mainfolder
286290
experiment_name = datasaver._dataset.exp_name
287291
sample_name = datasaver._dataset.sample_name
288-
289292
storage_dir = os.path.join(mainfolder, experiment_name, sample_name)
290293
os.makedirs(storage_dir, exist_ok=True)
291-
292294
png_dir = os.path.join(storage_dir, 'png')
293295
pdf_dif = os.path.join(storage_dir, 'pdf')
294-
295296
os.makedirs(png_dir, exist_ok=True)
296297
os.makedirs(pdf_dif, exist_ok=True)
297-
298298
save_pdf = True
299299
save_png = True
300-
301300
for i, ax in enumerate(axes):
302301
if save_pdf:
303302
full_path = os.path.join(pdf_dif, f'{dataid}_{i}.pdf')
@@ -306,4 +305,5 @@ def _save_image(datasaver) -> AxesTupleList:
306305
full_path = os.path.join(png_dir, f'{dataid}_{i}.png')
307306
ax.figure.savefig(full_path, dpi=500)
308307
plt.ion()
309-
return axes, cbs
308+
res = dataid, axes, cbs
309+
return res

0 commit comments

Comments
 (0)