Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
5fda87b
couple of bugfixes in leakage.py (make leaky_qubit_model_from_pspec r…
rileyjmurray Oct 17, 2025
70314b0
remove spurious second `arg = `
rileyjmurray Oct 22, 2025
54f6b55
input validation and doc changes for DirectSumUnitaryGroup
rileyjmurray Oct 22, 2025
f154b93
correct type annotations
rileyjmurray Oct 24, 2025
30d52d2
add private _direct_sum_unitary_group helper function in leakage.py. …
rileyjmurray Oct 24, 2025
0ef2bb0
have leakage_automagic.ipynb demonstrate ability to identify when a s…
rileyjmurray Oct 24, 2025
681bab9
notebook: reduce GST circuit depth and clear output
rileyjmurray Oct 24, 2025
7018fa6
steal infrastructure for setting objective function tolerances from a…
rileyjmurray Oct 24, 2025
3c8ef6e
make it slightly cheaper to run the example notebook. Make a proper c…
rileyjmurray Oct 24, 2025
4594615
Merge branch 'bugfix' into resolve-issue-652-smaller-scope-bugfix
rileyjmurray Oct 24, 2025
1acb0ed
actually use string specification of basis rather than insisting on str
rileyjmurray Oct 28, 2025
2e35382
undo accidental changes to objectivefns.py
rileyjmurray Oct 28, 2025
97eeee5
clarifying comment
rileyjmurray Oct 28, 2025
06cff96
Introduce ModelMember._to_transformed_dense. This has detailed semant…
rileyjmurray Oct 29, 2025
aca56bf
adjust imports; remove spurrious print statement
rileyjmurray Oct 29, 2025
62a8765
temporary implementations of povm_diamonddist and inst_diamonddist to…
rileyjmurray Oct 29, 2025
7a0a37c
typo
rileyjmurray Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions jupyter_notebooks/Examples/Leakage-automagic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
"metadata": {},
"outputs": [],
"source": [
"from pygsti.modelpacks import smq1Q_XY, smq1Q_ZN\n",
"from pygsti.modelpacks import smq1Q_XYI as mp\n",
"from pygsti.tools.leakage import leaky_qubit_model_from_pspec, construct_leakage_report\n",
"from pygsti.data import simulate_data\n",
"from pygsti.protocols import StandardGST, ProtocolData"
"from pygsti.protocols import StandardGST, ProtocolData\n",
"import numpy as np\n",
"import scipy.linalg as la"
]
},
{
Expand All @@ -27,14 +29,51 @@
"metadata": {},
"outputs": [],
"source": [
"mp = smq1Q_XY\n",
"ed = mp.create_gst_experiment_design(max_max_length=32)\n",
"def with_leaky_gate(m, gate_label, strength):\n",
" rng = np.random.default_rng(0)\n",
" v = np.concatenate([[0.0], rng.standard_normal(size=(2,))])\n",
" v /= la.norm(v)\n",
" H = v.reshape((-1, 1)) @ v.reshape((1, -1))\n",
" H *= strength\n",
" U = la.expm(1j*H)\n",
" m_copy = m.copy()\n",
" G_ideal = m_copy.operations[gate_label]\n",
" from pygsti.modelmembers.operations import ComposedOp, StaticUnitaryOp\n",
" m_copy.operations[gate_label] = ComposedOp([G_ideal, StaticUnitaryOp(U, basis=m.basis)])\n",
" return m_copy, v\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ed = mp.create_gst_experiment_design(max_max_length=8)\n",
"# ^ The default max length is small so we don't have to wait as long \n",
"# for the GST fit (just for purposes of this notebook).\n",
"tm3 = leaky_qubit_model_from_pspec(mp.processor_spec(), mx_basis='l2p1')\n",
"# ^ We could use basis = 'gm' instead of 'l2p1'. We prefer 'l2p1'\n",
"# because it makes process matrices easier to interpret in leakage\n",
"# modeling.\n",
"ds = simulate_data(tm3, ed.all_circuits_needing_data, num_samples=1000, seed=1997)\n",
"gst = StandardGST( modes=('CPTPLND',), target_model=tm3, verbosity=2)\n",
"# ^ Target model. \"Leaky\" is a bit of a misnomer here. The returned model\n",
"# is simply a qutrit lift of the qubit model; leakage erorrs in the\n",
"# qubit model can manifest as CPTP Markovian errors in the qutrit model.\n",
"dgm3, leaking_state = with_leaky_gate(tm3, ('Gxpi2', 0), strength=0.125)\n",
"# ^ Data generating model. \n",
"num_samples = 100_000\n",
"# ^ The number of samples is large to compensate for short circuit length.\n",
"# Feel free to change the number of samples to something more \"realistic\"\n",
"# if you'd like.\n",
"if num_samples > 10_000:\n",
" from pygsti.objectivefns import objectivefns\n",
" objectivefns.DEFAULT_MIN_PROB_CLIP = objectivefns.DEFAULT_RADIUS = 1e-12\n",
" # ^ There are numerical thresholding rules in objective function evaluation\n",
" # that lead to errors when the number of samples is extremely large.\n",
" # The lines above change those thresholding rules to be appropriate in\n",
" # the unusual setting that is this notebook.\n",
"ds = simulate_data(dgm3, ed.all_circuits_needing_data, num_samples=num_samples, seed=1997)\n",
"gst = StandardGST(\n",
" modes=('CPTPLND',), target_model=tm3, verbosity=2,\n",
" badfit_options={'actions': ['wildcard1d'], 'threshold': 0.0}\n",
")\n",
"pd = ProtocolData(ed, ds)\n",
"res = gst.run(pd)"
]
Expand Down
74 changes: 74 additions & 0 deletions pygsti/modelmembers/modelmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************

from __future__ import annotations

from collections import OrderedDict
import copy as _copy

import numpy as _np

from pygsti.baseobjs.nicelyserializable import NicelySerializable as _NicelySerializable
from pygsti.pgtypes import SpaceT
from pygsti.tools import listtools as _lt
from pygsti.tools import slicetools as _slct
from pygsti.tools import matrixtools as _mt

from typing import Optional


class ModelChild(object):
Expand Down Expand Up @@ -802,6 +808,74 @@ def copy(self, parent=None, memo=None):
def to_dense(self) -> _np.ndarray:
raise NotImplementedError('Derived classes must implement .to_dense().')

def _to_transformed_dense(self, T_domain: _mt.OperatorLike, T_codomain: _mt.OperatorLike, on_space: SpaceT='minimal') -> _np.ndarray:
"""
Return an array, XT, obtained by suitably transforming X = self.to_dense(on_space).

The basic nature of the transformation X --> XT depends on the category of `self`,
as determined by its domain and codomain.

| abstract category | domain | codomain |
| ----------------- | ------------ | ------------ |
| vector | field | vector space |
| functional | vector space | field |
| operator | vector space | vector space |

To state the specific transformation X --> XT, let op(X) denote the operator
representation of X obtained by (1) interpreting fields as 1-dimensional vector
spaces, and (2) having linear operators act on vectors by left-multiplication.

The returned array, XT, is defined through its op(XT) representation:

| abstract category | op(XT) representation of XT |
| ----------------- | ----------------------------- |
| vector | T_codomain @ op(X) |
| functional | op(X) @ T_domain |
| operator | T_codomain @ op(X) @ T_domain |

Note that T_domain is ignored for abstract vectors (i.e., state prep), and T_codomain
is ignored for abstract functionals (i.e., POVM effects).
"""
raise NotImplementedError()

def residuals(self, other: ModelMember,
transform: Optional[_mt.OperatorLike]=None, inv_transform: Optional[_mt.OperatorLike]=None
) -> _np.ndarray:
# This implementation was introduced as part of a heavy refactor, but it preserves all intended
# semantics of the old implementation.
T_domain = _mt.to_operatorlike(transform)
T_codomain = _mt.to_operatorlike(inv_transform)
# ^ to_operatorlike casts None to IdentityOperator
X = self._to_transformed_dense(T_domain, T_codomain)
if isinstance(inv_transform, _mt.IdentityOperator):
# Passing inv_transform as an IdentityOperator (rather than casting from None)
# is a flag. It indicates that we want to apply `transform` to `other` as well.
#
# (Yes, this sort of flag interpretation is bad design. No, I don't want to
# spend the time on a good design.)
Y = other._to_transformed_dense(T_domain, inv_transform)
else:
Y = other.to_dense()
return (X - Y).ravel()

def frobeniusdist_squared(self, other: ModelMember,
transform: Optional[_mt.OperatorLike]=None, inv_transform: Optional[_mt.OperatorLike]=None
) -> _np.floating:
"""
Return the squared Frobenius norm of the difference between `self` and `other`,
possibly after transformation by `transform` and/or `inv_transform`.
"""
return _np.linalg.norm(self.residuals(other, transform, inv_transform))**2

def frobeniusdist(self, other: ModelMember,
transform: Optional[_mt.OperatorLike]=None, inv_transform: Optional[_mt.OperatorLike]=None
) -> _np.floating:
"""
Return the Frobenius norm of the difference between `self` and `other`,
possibly after transformation by `transform` and/or `inv_transform`.
"""
return _np.linalg.norm(self.residuals(other, transform, inv_transform))

def _is_similar(self, other, rtol, atol):
""" Returns True if `other` model member (which it guaranteed to be the same type as self) has
the same local structure, i.e., not considering parameter values or submembers """
Expand Down
103 changes: 11 additions & 92 deletions pygsti/modelmembers/operations/linearop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************

from __future__ import annotations

import numpy as _np

from pygsti.baseobjs.opcalc import bulk_eval_compact_polynomials_complex as _bulk_eval_compact_polynomials_complex
from pygsti.modelmembers import modelmember as _modelmember
from pygsti.tools import optools as _ot
from pygsti.tools import matrixtools as _mt
from pygsti import SpaceT

from typing import Any

#Note on initialization sequence of Operations within a Model:
# 1) a Model is constructed (empty)
# 2) a LinearOperator is constructed - apart from a Model if it's locally parameterized,
Expand Down Expand Up @@ -156,6 +157,14 @@ def to_dense(self, on_space: SpaceT='minimal'):
"""
raise NotImplementedError("to_dense(...) not implemented for %s objects!" % self.__class__.__name__)

def _to_transformed_dense(self, T_domain: _mt.OperatorLike, T_codomain: _mt.OperatorLike, on_space: SpaceT='minimal') -> _np.ndarray:
"""
Return an array representation of the linear operator obtained by composing T_domain,
self.to_dense(), and T_codomain --- in that order.
"""
out = T_codomain @ self.to_dense(on_space=on_space) @ T_domain
return out

def acton(self, state, on_space='minimal'):
"""
Act with this operator upon `state`
Expand Down Expand Up @@ -391,96 +400,6 @@ def taylor_order_terms_above_mag(self, order, max_polynomial_vars, min_term_mag)

return [t for t in terms_at_order if t.magnitude >= min_term_mag]

def frobeniusdist_squared(self, other_op, transform=None, inv_transform=None) -> _np.floating[Any]:
"""
Return the squared frobenius difference between this operation and `other_op`

Optionally transforms this operation first using matrices
`transform` and `inv_transform`. Specifically, this operation gets
transformed as: `O => inv_transform * O * transform` before comparison with
`other_op`.

Parameters
----------
other_op : DenseOperator
The other operation.

transform : numpy.ndarray, optional
Transformation matrix.

inv_transform : numpy.ndarray, optional
Inverse of `transform`.

Returns
-------
float
"""
self_mx = self.to_dense("minimal")
if transform is not None:
self_mx = self_mx @ transform
if inv_transform is not None:
self_mx = inv_transform @ self_mx
return _ot.frobeniusdist_squared(self_mx, other_op.to_dense("minimal"))


def frobeniusdist(self, other_op, transform=None, inv_transform=None):
"""
Return the frobenius distance between this operation and `other_op`.

Optionally transforms this operation first using matrices
`transform` and `inv_transform`. Specifically, this operation gets
transformed as: `O => inv_transform * O * transform` before comparison with
`other_op`.

Parameters
----------
other_op : DenseOperator
The other operation.

transform : numpy.ndarray, optional
Transformation matrix.

inv_transform : numpy.ndarray, optional
Inverse of `transform`.

Returns
-------
float
"""
return _np.sqrt(self.frobeniusdist_squared(other_op, transform, inv_transform))

def residuals(self, other_op, transform=None, inv_transform=None):
"""
The per-element difference between this `DenseOperator` and `other_op`.

Optionally, tansforming this operation first as
`O => inv_transform * O * transform`.

Parameters
----------
other_op : DenseOperator
The operation to compare against.

transform : numpy.ndarray, optional
Transformation matrix.

inv_transform : numpy.ndarray, optional
Inverse of `transform`.

Returns
-------
numpy.ndarray
A 1D-array of size equal to that of the flattened operation matrix.
"""
dense_self = self.to_dense("minimal")
if transform is not None:
assert inv_transform is not None
dense_self = inv_transform @ (dense_self @ transform)
else:
assert inv_transform is None
return (dense_self - other_op.to_dense("minimal")).ravel()


def jtracedist(self, other_op, transform=None, inv_transform=None):
"""
Return the Jamiolkowski trace distance between this operation and `other_op`.
Expand Down
67 changes: 14 additions & 53 deletions pygsti/modelmembers/povms/effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
# in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************
from __future__ import annotations

import numpy as _np

from pygsti.pgtypes import SpaceT
from pygsti.modelmembers import modelmember as _modelmember
from pygsti.tools import optools as _ot
from pygsti.tools import matrixtools as _mt
from pygsti.baseobjs.opcalc import bulk_eval_compact_polynomials_complex as _bulk_eval_compact_polynomials_complex

from typing import Any


class POVMEffect(_modelmember.ModelMember):
"""
Expand Down Expand Up @@ -119,59 +120,19 @@ def set_time(self, t):

## PUT term calc methods here if appropriate...

def frobeniusdist_squared(self, other_spam_vec, transform=None, inv_transform=None) -> _np.floating[Any]:
"""
Return the squared frobenius difference between this effect and `other_spam_vec`.

Optionally transforms this vector first using `transform`.

Parameters
----------
other_spam_vec : POVMEffect
The other spam vector

transform : numpy.ndarray, optional
Transformation matrix.

inv_transform : numpy.ndarray, optional
Ignored. (We keep this as a positional argument for consistency with
the frobeniusdist_squared method of pyGSTi's LinearOperator objects.)

Returns
-------
float
"""
vec = self.to_dense()
if transform is not None:
vec = transform.T @ vec
return _ot.frobeniusdist_squared(vec, other_spam_vec.to_dense())

def residuals(self, other_spam_vec, transform=None, inv_transform=None):
def _to_transformed_dense(self, T_domain: _mt.OperatorLike, T_codomain: _mt.OperatorLike, on_space: SpaceT='minimal') -> _np.ndarray:
"""
Return a vector of residuals between this spam vector and `other_spam_vec`.

Optionally transforms this vector first using `transform` and
`inv_transform`.

Parameters
----------
other_spam_vec : POVMEffect
The other spam vector

transform : numpy.ndarray, optional
Transformation matrix.

inv_transform : numpy.ndarray, optional
Inverse of `tranform`.

Returns
-------
float
Return an array representation of the linear operator obtained by composing T_domain
and self.to_dense(). The representation interprets POVM effects as linear functionals
on density matrices. It allows for --- but does not strictly require --- the convention
that POVM effects are represented as column-vector superkets.

T_codomain (ignored) is only here for consistency across the ModelMember API.
"""
vec = self.to_dense()
if transform is not None:
vec = transform.T @ vec
return (vec - other_spam_vec.to_dense()).ravel()
X = self.to_dense(on_space=on_space) # type: ignore
assert X.ndim == 1 or X.shape[1] == 1
out = T_domain.T @ X
return out

def transform_inplace(self, s):
"""
Expand Down
Loading