Skip to content

Commit 9fdc011

Browse files
nkanazawa1989eggerdjchriseclectic
authored
CurveAnalysis base class (#765)
Refactor of CurveAnalysis and introduction of BaseCurveAnalysis class Co-authored-by: Daniel J. Egger <[email protected]> Co-authored-by: Christopher J. Wood <[email protected]>
1 parent 10e0bee commit 9fdc011

20 files changed

+1677
-1235
lines changed

qiskit_experiments/curve_analysis/__init__.py

Lines changed: 461 additions & 37 deletions
Large diffs are not rendered by default.

qiskit_experiments/curve_analysis/base_curve_analysis.py

Lines changed: 547 additions & 0 deletions
Large diffs are not rendered by default.

qiskit_experiments/curve_analysis/curve_analysis.py

Lines changed: 124 additions & 834 deletions
Large diffs are not rendered by default.

qiskit_experiments/curve_analysis/curve_data.py

Lines changed: 99 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -25,94 +25,138 @@
2525

2626
@dataclasses.dataclass(frozen=True)
2727
class SeriesDef:
28-
"""Description of curve."""
28+
"""A dataclass to describe the definition of the curve.
29+
30+
Attributes:
31+
fit_func: A callable that defines the fit model of this curve. The argument names
32+
in the callable are parsed to create the fit parameter list, which will appear
33+
in the analysis results. The first argument should be ``x``, which represents the
34+
X-values that the experiment sweeps.
35+
filter_kwargs: Optional. Dictionary of properties that uniquely identifies this series.
36+
This dictionary is used for data processing.
37+
This must be provided when the curve analysis consists of multiple series.
38+
name: Optional. Name of this series.
39+
plot_color: Optional. String representation of the color that is used to draw fit data
40+
and data points in the output figure. This depends on the drawer class
41+
being set to the curve analysis options. Usually this conforms to the
42+
Matplotlib color names.
43+
plot_symbol: Optional. String representation of the marker shape that is used to draw
44+
data points in the output figure. This depends on the drawer class
45+
being set to the curve analysis options. Usually this conforms to the
46+
Matplotlib symbol names.
47+
canvas: Optional. Index of sub-axis in the output figure that draws this curve.
48+
This option is valid only when the drawer instance provides multi-axis drawing.
49+
model_description: Optional. Arbitrary string representation of this fit model.
50+
This string will appear in the analysis results as a part of metadata.
51+
"""
2952

30-
# Arbitrary callback to define the fit function. First argument should be x.
3153
fit_func: Callable
32-
33-
# Keyword dictionary to define the series with circuit metadata
3454
filter_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
35-
36-
# Name of this series. This name will appear in the figure and raw x-y value report.
3755
name: str = "Series-0"
38-
39-
# Color of this line.
4056
plot_color: str = "black"
41-
42-
# Symbol to represent data points of this line.
4357
plot_symbol: str = "o"
44-
45-
# Latex description of this fit model
46-
model_description: Optional[str] = None
47-
48-
# Index of canvas if the result figure is multi-panel
4958
canvas: Optional[int] = None
50-
51-
# Automatically extracted signature of the fit function
52-
signature: List[str] = dataclasses.field(init=False)
59+
model_description: Optional[str] = None
60+
signature: Tuple[str, ...] = dataclasses.field(init=False)
5361

5462
def __post_init__(self):
5563
"""Parse the fit function signature to extract the names of the variables.
5664
5765
Fit functions take arguments F(x, p0, p1, p2, ...) thus the first value should be excluded.
5866
"""
5967
signature = list(inspect.signature(self.fit_func).parameters.keys())
60-
fitparams = signature[1:]
68+
fitparams = tuple(signature[1:])
6169

6270
# Note that this dataclass is frozen
6371
object.__setattr__(self, "signature", fitparams)
6472

6573

6674
@dataclasses.dataclass(frozen=True)
6775
class CurveData:
68-
"""Set of extracted experiment data."""
69-
70-
# Name of this data set
71-
label: str
76+
"""A dataclass that manages the multiple arrays comprising the dataset for fitting.
77+
78+
This dataset can consist of X, Y values from multiple series.
79+
To extract curve data of the particular series, :meth:`get_subset_of` can be used.
80+
81+
Attributes:
82+
x: X-values that the experiment sweeps.
83+
y: Y-values that are observed and processed by the data processor.
84+
y_err: Uncertainty of the Y-values which is created by the data processor.
85+
Usually this assumes standard error.
86+
shots: Number of shots used in the experiment to obtain the Y-values.
87+
data_allocation: List with the same size as the other arrays.
88+
The value indicates the series index of the corresponding element.
89+
This is classified based upon the matching of :attr:`SeriesDef.filter_kwargs`
90+
with the circuit metadata of the corresponding data index.
91+
If the metadata doesn't match any series definition, the element is filled with ``-1``.
92+
labels: List of curve labels. The list index corresponds to the series index.
93+
"""
7294

73-
# X data
7495
x: np.ndarray
75-
76-
# Y data (measured data)
7796
y: np.ndarray
78-
79-
# Error bar
8097
y_err: np.ndarray
81-
82-
# Shots number
8398
shots: np.ndarray
99+
data_allocation: np.ndarray
100+
labels: List[str]
84101

85-
# Maping of data index to series index
86-
data_index: Union[np.ndarray, int]
102+
def get_subset_of(self, index: Union[str, int]) -> "CurveData":
103+
"""Filter data by series name or index.
87104
88-
# Metadata associated with each data point. Generated from the circuit metadata.
89-
metadata: np.ndarray = None
105+
Args:
106+
index: Series index or name.
107+
108+
Returns:
109+
A subset of data corresponding to a particular series.
110+
"""
111+
if isinstance(index, int):
112+
_index = index
113+
_name = self.labels[index]
114+
else:
115+
_index = self.labels.index(index)
116+
_name = index
117+
118+
locs = self.data_allocation == _index
119+
return CurveData(
120+
x=self.x[locs],
121+
y=self.y[locs],
122+
y_err=self.y_err[locs],
123+
shots=self.shots[locs],
124+
data_allocation=np.full(np.count_nonzero(locs), _index),
125+
labels=[_name],
126+
)
90127

91128

92129
@dataclasses.dataclass(frozen=True)
93130
class FitData:
94-
"""Set of data generated by the fit function."""
131+
"""A dataclass to store the outcome of the fitting.
132+
133+
Attributes:
134+
popt: List of optimal parameter values with uncertainties if available.
135+
popt_keys: List of parameter names being fit.
136+
pcov: Covariance matrix from the least-squares fitting.
137+
reduced_chisq: Reduced Chi-squared value for the fit curve.
138+
dof: Degrees of freedom in this fit model.
139+
x_data: X-values provided to the fitter.
140+
y_data: Y-values provided to the fitter.
141+
"""
95142

96-
# Order sensitive fit parameter values
97143
popt: List[uncertainties.UFloat]
98-
99-
# Order sensitive parameter name list
100144
popt_keys: List[str]
101-
102-
# Covariance matrix
103145
pcov: np.ndarray
104-
105-
# Reduced Chi-squared value of fit curve
106146
reduced_chisq: float
107-
108-
# Degree of freedom
109147
dof: int
148+
x_data: np.ndarray
149+
y_data: np.ndarray
110150

111-
# X data range
112-
x_range: Tuple[float, float]
151+
@property
152+
def x_range(self) -> Tuple[float, float]:
153+
"""Range of x values."""
154+
return np.min(self.x_data), np.max(self.x_data)
113155

114-
# Y data range
115-
y_range: Tuple[float, float]
156+
@property
157+
def y_range(self) -> Tuple[float, float]:
158+
"""Range of y values."""
159+
return np.min(self.y_data), np.max(self.y_data)
116160

117161
def fitval(self, key: str) -> uncertainties.UFloat:
118162
"""A helper method to get fit value object from parameter key name.
@@ -136,7 +180,13 @@ def fitval(self, key: str) -> uncertainties.UFloat:
136180

137181
@dataclasses.dataclass
138182
class ParameterRepr:
139-
"""Detailed description of fitting parameter."""
183+
"""Detailed description of fitting parameter.
184+
185+
Attributes:
186+
name: Original name of the fit parameter being defined in the fit model.
187+
repr: Optional. Human-readable parameter name shown in the analysis result and in the figure.
188+
unit: Optional. Physical unit of this parameter if applicable.
189+
"""
140190

141191
# Fitter argument name
142192
name: str

qiskit_experiments/curve_analysis/curve_fit.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,18 +155,14 @@ def fit_func(x, *params):
155155
residues = residues / (sigma**2)
156156
reduced_chisq = np.sum(residues) / dof
157157

158-
# Compute data range for fit
159-
xdata_range = np.min(xdata), np.max(xdata)
160-
ydata_range = np.min(ydata), np.max(ydata)
161-
162158
return FitData(
163159
popt=list(fit_params),
164160
popt_keys=list(param_keys),
165161
pcov=pcov,
166162
reduced_chisq=reduced_chisq,
167163
dof=dof,
168-
x_range=xdata_range,
169-
y_range=ydata_range,
164+
x_data=xdata,
165+
y_data=ydata,
170166
)
171167

172168

qiskit_experiments/curve_analysis/standard_analysis/decay.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,19 @@ class DecayAnalysis(curve.CurveAnalysis):
6161
]
6262

6363
def _generate_fit_guesses(
64-
self, user_opt: curve.FitOptions
64+
self,
65+
user_opt: curve.FitOptions,
66+
curve_data: curve.CurveData,
6567
) -> Union[curve.FitOptions, List[curve.FitOptions]]:
66-
"""Compute the initial guesses.
68+
"""Create algorithmic guess with analysis options and curve data.
6769
6870
Args:
6971
user_opt: Fit options filled with user provided guess and bounds.
72+
curve_data: Formatted data collection to fit.
7073
7174
Returns:
7275
List of fit options that are passed to the fitter function.
73-
74-
Raises:
75-
AnalysisError: When the y data is likely constant.
7676
"""
77-
curve_data = self._data()
78-
7977
user_opt.p0.set_if_empty(base=curve.guess.min_height(curve_data.y)[0])
8078

8179
alpha = curve.guess.exp_decay(curve_data.x, curve_data.y)

qiskit_experiments/curve_analysis/standard_analysis/error_amplification_analysis.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,6 @@ class ErrorAmplificationAnalysis(curve.CurveAnalysis):
7676
often correspond to symmetry points of the fit function. Furthermore,
7777
this type of analysis is intended for values of :math:`d\theta` close to zero.
7878
79-
# section: note
80-
81-
Different analysis classes may subclass this class to fix some of the fit parameters.
8279
"""
8380

8481
__series__ = [
@@ -109,7 +106,7 @@ def _default_options(cls):
109106
considered as good. Defaults to :math:`\pi/2`.
110107
"""
111108
default_options = super()._default_options()
112-
default_options.curve_plotter.set_options(
109+
default_options.curve_drawer.set_options(
113110
xlabel="Number of gates (n)",
114111
ylabel="Population",
115112
ylim=(0, 1.0),
@@ -120,22 +117,21 @@ def _default_options(cls):
120117
return default_options
121118

122119
def _generate_fit_guesses(
123-
self, user_opt: curve.FitOptions
120+
self,
121+
user_opt: curve.FitOptions,
122+
curve_data: curve.CurveData,
124123
) -> Union[curve.FitOptions, List[curve.FitOptions]]:
125-
"""Compute the initial guesses.
124+
"""Create algorithmic guess with analysis options and curve data.
126125
127126
Args:
128127
user_opt: Fit options filled with user provided guess and bounds.
128+
curve_data: Formatted data collection to fit.
129129
130130
Returns:
131131
List of fit options that are passed to the fitter function.
132-
133-
Raises:
134-
CalibrationError: When ``angle_per_gate`` is missing.
135132
"""
136133
fixed_params = self.options.fixed_parameters
137134

138-
curve_data = self._data()
139135
max_abs_y, _ = curve.guess.max_height(curve_data.y, absolute=True)
140136
max_y, min_y = np.max(curve_data.y), np.min(curve_data.y)
141137

qiskit_experiments/curve_analysis/standard_analysis/gaussian.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class GaussianAnalysis(curve.CurveAnalysis):
7171
@classmethod
7272
def _default_options(cls) -> Options:
7373
options = super()._default_options()
74-
options.curve_plotter.set_options(
74+
options.curve_drawer.set_options(
7575
xlabel="Frequency",
7676
ylabel="Signal (arb. units)",
7777
xval_unit="Hz",
@@ -81,17 +81,19 @@ def _default_options(cls) -> Options:
8181
return options
8282

8383
def _generate_fit_guesses(
84-
self, user_opt: curve.FitOptions
84+
self,
85+
user_opt: curve.FitOptions,
86+
curve_data: curve.CurveData,
8587
) -> Union[curve.FitOptions, List[curve.FitOptions]]:
86-
"""Compute the initial guesses.
88+
"""Create algorithmic guess with analysis options and curve data.
8789
8890
Args:
8991
user_opt: Fit options filled with user provided guess and bounds.
92+
curve_data: Formatted data collection to fit.
9093
9194
Returns:
9295
List of fit options that are passed to the fitter function.
9396
"""
94-
curve_data = self._data()
9597
max_abs_y, _ = curve.guess.max_height(curve_data.y, absolute=True)
9698

9799
user_opt.bounds.set_if_empty(
@@ -128,22 +130,18 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
128130
threshold of two, and
129131
- a standard error on the sigma of the Gaussian that is smaller than the sigma.
130132
"""
131-
curve_data = self._data()
132-
133-
max_freq = np.max(curve_data.x)
134-
min_freq = np.min(curve_data.x)
135-
freq_increment = np.mean(np.diff(curve_data.x))
133+
freq_increment = np.mean(np.diff(fit_data.x_data))
136134

137135
fit_a = fit_data.fitval("a")
138136
fit_b = fit_data.fitval("b")
139137
fit_freq = fit_data.fitval("freq")
140138
fit_sigma = fit_data.fitval("sigma")
141139

142-
snr = abs(fit_a.n) / np.sqrt(abs(np.median(curve_data.y) - fit_b.n))
143-
fit_width_ratio = fit_sigma.n / (max_freq - min_freq)
140+
snr = abs(fit_a.n) / np.sqrt(abs(np.median(fit_data.y_data) - fit_b.n))
141+
fit_width_ratio = fit_sigma.n / np.ptp(fit_data.x_data)
144142

145143
criteria = [
146-
min_freq <= fit_freq.n <= max_freq,
144+
fit_data.x_range[0] <= fit_freq.n <= fit_data.x_range[1],
147145
1.5 * freq_increment < fit_sigma.n,
148146
fit_width_ratio < 0.25,
149147
fit_data.reduced_chisq < 3,

0 commit comments

Comments
 (0)