Skip to content

Commit 6119620

Browse files
authored
Add Transformer API Interface and @cirq.transformer decorator (#4797)
Defines `TRANSFORMER_TYPE` and Implements the `@transformer` decorator, as proposed in https://tinyurl.com/cirq-circuit-transformers-api All existing transformers will be rewritten to follow the new API once this is checked-in. Implementation of `TransformerStatsLoggerBase` will follow in a separate PR. Part of #4483 cc @maffoo PTAL at all the mypy magic.
1 parent 394b9cf commit 6119620

File tree

5 files changed

+564
-0
lines changed

5 files changed

+564
-0
lines changed

cirq-core/cirq/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,11 @@
365365
single_qubit_matrix_to_phased_x_z,
366366
single_qubit_matrix_to_phxz,
367367
single_qubit_op_to_framed_phase_form,
368+
TRANSFORMER,
369+
TransformerContext,
370+
TransformerLogger,
368371
three_qubit_matrix_to_operations,
372+
transformer,
369373
two_qubit_matrix_to_diagonal_and_operations,
370374
two_qubit_matrix_to_operations,
371375
two_qubit_matrix_to_sqrt_iswap_operations,

cirq-core/cirq/protocols/json_test_data/spec.py

+4
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@
110110
'MergeSingleQubitGates',
111111
'PointOptimizer',
112112
'SynchronizeTerminalMeasurements',
113+
# Transformers
114+
'TransformerLogger',
115+
'TransformerContext',
113116
# global objects
114117
'CONTROL_TAG',
115118
'PAULI_BASIS',
@@ -172,6 +175,7 @@
172175
'Sweepable',
173176
'TParamKey',
174177
'TParamVal',
178+
'TRANSFORMER',
175179
'ParamDictType',
176180
# utility:
177181
'CliffordSimulator',

cirq-core/cirq/transformers/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@
4141
two_qubit_gate_product_tabulation,
4242
)
4343

44+
from cirq.transformers.transformer_api import (
45+
LogLevel,
46+
TRANSFORMER,
47+
TransformerContext,
48+
TransformerLogger,
49+
transformer,
50+
)
51+
4452
from cirq.transformers.transformer_primitives import (
4553
map_moments,
4654
map_operations,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# Copyright 2022 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Defines the API for circuit transformers in Cirq."""
16+
17+
import textwrap
18+
import functools
19+
from typing import (
20+
Any,
21+
Callable,
22+
Tuple,
23+
Hashable,
24+
List,
25+
Type,
26+
overload,
27+
TYPE_CHECKING,
28+
)
29+
import dataclasses
30+
import enum
31+
from cirq.circuits.circuit import CIRCUIT_TYPE
32+
33+
if TYPE_CHECKING:
34+
import cirq
35+
36+
37+
class LogLevel(enum.Enum):
38+
"""Different logging resolution options for `cirq.TransformerLogger`.
39+
40+
The enum values of the logging levels are used to filter the stored logs when printing.
41+
In general, a logging level `X` includes all logs stored at a level >= 'X'.
42+
43+
Args:
44+
ALL: All levels. Used to filter logs when printing.
45+
DEBUG: Designates fine-grained informational events that are most useful to debug /
46+
understand in-depth any unexpected behavior of the transformer.
47+
INFO: Designates informational messages that highlight the actions of a transformer.
48+
WARNING: Designates unwanted or potentially harmful situations.
49+
NONE: No levels. Used to filter logs when printing.
50+
"""
51+
52+
ALL = 0
53+
DEBUG = 10
54+
INFO = 20
55+
WARNING = 30
56+
NONE = 40
57+
58+
59+
@dataclasses.dataclass
60+
class _LoggerNode:
61+
"""Stores logging data of a single transformer stage in `cirq.TransformerLogger`.
62+
63+
The class is used to define a logging graph to store logs of sequential or nested transformers.
64+
Each node corresponds to logs of a single transformer stage.
65+
66+
Args:
67+
transformer_id: Integer specifying a unique id for corresponding transformer stage.
68+
transformer_name: Name of the corresponding transformer stage.
69+
initial_circuit: Initial circuit before the transformer stage began.
70+
final_circuit: Final circuit after the transformer stage ended.
71+
logs: Messages logged by the transformer stage.
72+
nested_loggers: `transformer_id`s of nested transformer stages which were called by
73+
the current stage.
74+
"""
75+
76+
transformer_id: int
77+
transformer_name: str
78+
initial_circuit: 'cirq.AbstractCircuit'
79+
final_circuit: 'cirq.AbstractCircuit'
80+
logs: List[Tuple[LogLevel, Tuple[str, ...]]] = dataclasses.field(default_factory=list)
81+
nested_loggers: List[int] = dataclasses.field(default_factory=list)
82+
83+
84+
class TransformerLogger:
85+
"""Base Class for transformer logging infrastructure. Defaults to text-based logging.
86+
87+
The logger implementation should be stateful, s.t.:
88+
- Each call to `register_initial` registers a new transformer stage and initial circuit.
89+
- Each subsequent call to `log` should store additional logs corresponding to the stage.
90+
- Each call to `register_final` should register the end of the currently active stage.
91+
92+
The logger assumes that
93+
- Transformers are run sequentially.
94+
- Nested transformers are allowed, in which case the behavior would be similar to a
95+
doing a depth first search on the graph of transformers -- i.e. the top level transformer
96+
would end (i.e. receive a `register_final` call) once all nested transformers (i.e. all
97+
`register_initial` calls received while the top level transformer was active) have
98+
finished (i.e. corresponding `register_final` calls have also been received).
99+
- This behavior can be simulated by maintaining a stack of currently active stages and
100+
adding data from `log` calls to the stage at the top of the stack.
101+
102+
The `LogLevel`s can be used to control the input processing and output resolution of the logs.
103+
"""
104+
105+
def __init__(self):
106+
"""Initializes TransformerLogger."""
107+
self._curr_id: int = 0
108+
self._logs: List[_LoggerNode] = []
109+
self._stack: List[int] = []
110+
111+
def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
112+
"""Register the beginning of a new transformer stage.
113+
114+
Args:
115+
circuit: Input circuit to the new transformer stage.
116+
transformer_name: Name of the new transformer stage.
117+
"""
118+
if self._stack:
119+
self._logs[self._stack[-1]].nested_loggers.append(self._curr_id)
120+
self._logs.append(_LoggerNode(self._curr_id, transformer_name, circuit, circuit))
121+
self._stack.append(self._curr_id)
122+
self._curr_id += 1
123+
124+
def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None:
125+
"""Log additional metadata corresponding to the currently active transformer stage.
126+
127+
Args:
128+
*args: The additional metadata to log.
129+
level: Logging level to control the amount of metadata that gets put into the context.
130+
131+
Raises:
132+
ValueError: If there's no active transformer on the stack.
133+
"""
134+
if len(self._stack) == 0:
135+
raise ValueError('No active transformer found.')
136+
self._logs[self._stack[-1]].logs.append((level, args))
137+
138+
def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
139+
"""Register the end of the currently active transformer stage.
140+
141+
Args:
142+
circuit: Final transformed output circuit from the transformer stage.
143+
transformer_name: Name of the (currently active) transformer stage which ends.
144+
145+
Raises:
146+
ValueError: If `transformer_name` is different from currently active transformer name.
147+
"""
148+
tid = self._stack.pop()
149+
if self._logs[tid].transformer_name != transformer_name:
150+
raise ValueError(
151+
f"Expected `register_final` call for currently active transformer "
152+
f"{self._logs[tid].transformer_name}."
153+
)
154+
self._logs[tid].final_circuit = circuit
155+
156+
def show(self, level: LogLevel = LogLevel.INFO) -> None:
157+
"""Show the stored logs >= level in the desired format.
158+
159+
Args:
160+
level: The logging level to filter the logs with. The method shows all logs with a
161+
`LogLevel` >= `level`.
162+
"""
163+
164+
def print_log(log: _LoggerNode, pad=''):
165+
print(pad, f"Transformer-{1+log.transformer_id}: {log.transformer_name}", sep='')
166+
print(pad, "Initial Circuit:", sep='')
167+
print(textwrap.indent(str(log.initial_circuit), pad), "\n", sep='')
168+
for log_level, log_text in log.logs:
169+
if log_level.value >= level.value:
170+
print(pad, log_level, *log_text)
171+
print("\n", pad, "Final Circuit:", sep='')
172+
print(textwrap.indent(str(log.final_circuit), pad))
173+
print("----------------------------------------")
174+
175+
done = [0] * self._curr_id
176+
for i in range(self._curr_id):
177+
# Iterative DFS.
178+
stack = [(i, '')] if not done[i] else []
179+
while len(stack) > 0:
180+
log_id, pad = stack.pop()
181+
print_log(self._logs[log_id], pad)
182+
done[log_id] = True
183+
for child_id in self._logs[log_id].nested_loggers[::-1]:
184+
stack.append((child_id, pad + ' ' * 4))
185+
186+
187+
class NoOpTransformerLogger(TransformerLogger):
188+
"""All calls to this logger are a no-op"""
189+
190+
def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
191+
pass
192+
193+
def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None:
194+
pass
195+
196+
def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None:
197+
pass
198+
199+
def show(self, level: LogLevel = LogLevel.INFO) -> None:
200+
pass
201+
202+
203+
@dataclasses.dataclass()
204+
class TransformerContext:
205+
"""Stores common configurable options for transformers.
206+
207+
Args:
208+
logger: `cirq.TransformerLogger` instance, which is a stateful logger used for logging
209+
the actions of individual transformer stages. The same logger instance should be
210+
shared across different transformer calls.
211+
ignore_tags: Tuple of tags which should be ignored while applying transformations on a
212+
circuit. Transformers should not transform any operation marked with a tag that
213+
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
214+
`cirq.VirtualTag` etc.) is a valid tag.
215+
"""
216+
217+
logger: TransformerLogger = NoOpTransformerLogger()
218+
ignore_tags: Tuple[Hashable, ...] = ()
219+
220+
221+
TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit']
222+
_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE]
223+
224+
225+
def _transform_and_log(
226+
func: _TRANSFORMER_TYPE[CIRCUIT_TYPE],
227+
transformer_name: str,
228+
circuit: 'cirq.AbstractCircuit',
229+
context: TransformerContext,
230+
) -> CIRCUIT_TYPE:
231+
"""Helper to log initial and final circuits before and after calling the transformer."""
232+
233+
context.logger.register_initial(circuit, transformer_name)
234+
transformed_circuit = func(circuit, context)
235+
context.logger.register_final(transformed_circuit, transformer_name)
236+
return transformed_circuit
237+
238+
239+
def _transformer_class(
240+
cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
241+
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
242+
old_func = cls.__call__
243+
244+
def transformer_with_logging_cls(
245+
self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
246+
circuit: 'cirq.AbstractCircuit',
247+
context: TransformerContext,
248+
) -> CIRCUIT_TYPE:
249+
def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE:
250+
return old_func(self, c, ct)
251+
252+
return _transform_and_log(call_old_func, cls.__name__, circuit, context)
253+
254+
setattr(cls, '__call__', transformer_with_logging_cls)
255+
return cls
256+
257+
258+
def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
259+
@functools.wraps(func)
260+
def transformer_with_logging_func(
261+
circuit: 'cirq.AbstractCircuit',
262+
context: TransformerContext,
263+
) -> CIRCUIT_TYPE:
264+
return _transform_and_log(func, func.__name__, circuit, context)
265+
266+
return transformer_with_logging_func
267+
268+
269+
@overload
270+
def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
271+
pass
272+
273+
274+
@overload
275+
def transformer(
276+
cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
277+
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
278+
pass
279+
280+
281+
def transformer(cls_or_func: Any) -> Any:
282+
"""Decorator to verify API and append logging functionality to transformer functions & classes.
283+
284+
The decorated function or class must satisfy
285+
`Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
286+
287+
>>> @cirq.transformer
288+
>>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
289+
>>> ...
290+
291+
The decorated class must implement the `__call__` method to satisfy the above API.
292+
293+
>>> @cirq.transformer
294+
>>> class ConvertToSqrtISwaps:
295+
>>> def __init__(self):
296+
>>> ...
297+
>>> def __call__(
298+
>>> self, circuit: cirq.Circuit, context: cirq.TransformerContext
299+
>>> ) -> cirq.Circuit:
300+
>>> ...
301+
302+
Args:
303+
cls_or_func: The callable class or method to be decorated.
304+
305+
Returns:
306+
Decorated class / method which includes additional logging boilerplate. The decorated
307+
callable always receives a copy of the input circuit so that the input is never mutated.
308+
"""
309+
if isinstance(cls_or_func, type):
310+
return _transformer_class(cls_or_func)
311+
else:
312+
assert callable(cls_or_func)
313+
return _transformer_func(cls_or_func)

0 commit comments

Comments
 (0)