|
| 1 | +# Copyright 2022 The Cirq Developers |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Defines the API for circuit transformers in Cirq.""" |
| 16 | + |
| 17 | +import textwrap |
| 18 | +import functools |
| 19 | +from typing import ( |
| 20 | + Any, |
| 21 | + Callable, |
| 22 | + Tuple, |
| 23 | + Hashable, |
| 24 | + List, |
| 25 | + Type, |
| 26 | + overload, |
| 27 | + TYPE_CHECKING, |
| 28 | +) |
| 29 | +import dataclasses |
| 30 | +import enum |
| 31 | +from cirq.circuits.circuit import CIRCUIT_TYPE |
| 32 | + |
| 33 | +if TYPE_CHECKING: |
| 34 | + import cirq |
| 35 | + |
| 36 | + |
| 37 | +class LogLevel(enum.Enum): |
| 38 | + """Different logging resolution options for `cirq.TransformerLogger`. |
| 39 | +
|
| 40 | + The enum values of the logging levels are used to filter the stored logs when printing. |
| 41 | + In general, a logging level `X` includes all logs stored at a level >= 'X'. |
| 42 | +
|
| 43 | + Args: |
| 44 | + ALL: All levels. Used to filter logs when printing. |
| 45 | + DEBUG: Designates fine-grained informational events that are most useful to debug / |
| 46 | + understand in-depth any unexpected behavior of the transformer. |
| 47 | + INFO: Designates informational messages that highlight the actions of a transformer. |
| 48 | + WARNING: Designates unwanted or potentially harmful situations. |
| 49 | + NONE: No levels. Used to filter logs when printing. |
| 50 | + """ |
| 51 | + |
| 52 | + ALL = 0 |
| 53 | + DEBUG = 10 |
| 54 | + INFO = 20 |
| 55 | + WARNING = 30 |
| 56 | + NONE = 40 |
| 57 | + |
| 58 | + |
| 59 | +@dataclasses.dataclass |
| 60 | +class _LoggerNode: |
| 61 | + """Stores logging data of a single transformer stage in `cirq.TransformerLogger`. |
| 62 | +
|
| 63 | + The class is used to define a logging graph to store logs of sequential or nested transformers. |
| 64 | + Each node corresponds to logs of a single transformer stage. |
| 65 | +
|
| 66 | + Args: |
| 67 | + transformer_id: Integer specifying a unique id for corresponding transformer stage. |
| 68 | + transformer_name: Name of the corresponding transformer stage. |
| 69 | + initial_circuit: Initial circuit before the transformer stage began. |
| 70 | + final_circuit: Final circuit after the transformer stage ended. |
| 71 | + logs: Messages logged by the transformer stage. |
| 72 | + nested_loggers: `transformer_id`s of nested transformer stages which were called by |
| 73 | + the current stage. |
| 74 | + """ |
| 75 | + |
| 76 | + transformer_id: int |
| 77 | + transformer_name: str |
| 78 | + initial_circuit: 'cirq.AbstractCircuit' |
| 79 | + final_circuit: 'cirq.AbstractCircuit' |
| 80 | + logs: List[Tuple[LogLevel, Tuple[str, ...]]] = dataclasses.field(default_factory=list) |
| 81 | + nested_loggers: List[int] = dataclasses.field(default_factory=list) |
| 82 | + |
| 83 | + |
| 84 | +class TransformerLogger: |
| 85 | + """Base Class for transformer logging infrastructure. Defaults to text-based logging. |
| 86 | +
|
| 87 | + The logger implementation should be stateful, s.t.: |
| 88 | + - Each call to `register_initial` registers a new transformer stage and initial circuit. |
| 89 | + - Each subsequent call to `log` should store additional logs corresponding to the stage. |
| 90 | + - Each call to `register_final` should register the end of the currently active stage. |
| 91 | +
|
| 92 | + The logger assumes that |
| 93 | + - Transformers are run sequentially. |
| 94 | + - Nested transformers are allowed, in which case the behavior would be similar to a |
| 95 | + doing a depth first search on the graph of transformers -- i.e. the top level transformer |
| 96 | + would end (i.e. receive a `register_final` call) once all nested transformers (i.e. all |
| 97 | + `register_initial` calls received while the top level transformer was active) have |
| 98 | + finished (i.e. corresponding `register_final` calls have also been received). |
| 99 | + - This behavior can be simulated by maintaining a stack of currently active stages and |
| 100 | + adding data from `log` calls to the stage at the top of the stack. |
| 101 | +
|
| 102 | + The `LogLevel`s can be used to control the input processing and output resolution of the logs. |
| 103 | + """ |
| 104 | + |
| 105 | + def __init__(self): |
| 106 | + """Initializes TransformerLogger.""" |
| 107 | + self._curr_id: int = 0 |
| 108 | + self._logs: List[_LoggerNode] = [] |
| 109 | + self._stack: List[int] = [] |
| 110 | + |
| 111 | + def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None: |
| 112 | + """Register the beginning of a new transformer stage. |
| 113 | +
|
| 114 | + Args: |
| 115 | + circuit: Input circuit to the new transformer stage. |
| 116 | + transformer_name: Name of the new transformer stage. |
| 117 | + """ |
| 118 | + if self._stack: |
| 119 | + self._logs[self._stack[-1]].nested_loggers.append(self._curr_id) |
| 120 | + self._logs.append(_LoggerNode(self._curr_id, transformer_name, circuit, circuit)) |
| 121 | + self._stack.append(self._curr_id) |
| 122 | + self._curr_id += 1 |
| 123 | + |
| 124 | + def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None: |
| 125 | + """Log additional metadata corresponding to the currently active transformer stage. |
| 126 | +
|
| 127 | + Args: |
| 128 | + *args: The additional metadata to log. |
| 129 | + level: Logging level to control the amount of metadata that gets put into the context. |
| 130 | +
|
| 131 | + Raises: |
| 132 | + ValueError: If there's no active transformer on the stack. |
| 133 | + """ |
| 134 | + if len(self._stack) == 0: |
| 135 | + raise ValueError('No active transformer found.') |
| 136 | + self._logs[self._stack[-1]].logs.append((level, args)) |
| 137 | + |
| 138 | + def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None: |
| 139 | + """Register the end of the currently active transformer stage. |
| 140 | +
|
| 141 | + Args: |
| 142 | + circuit: Final transformed output circuit from the transformer stage. |
| 143 | + transformer_name: Name of the (currently active) transformer stage which ends. |
| 144 | +
|
| 145 | + Raises: |
| 146 | + ValueError: If `transformer_name` is different from currently active transformer name. |
| 147 | + """ |
| 148 | + tid = self._stack.pop() |
| 149 | + if self._logs[tid].transformer_name != transformer_name: |
| 150 | + raise ValueError( |
| 151 | + f"Expected `register_final` call for currently active transformer " |
| 152 | + f"{self._logs[tid].transformer_name}." |
| 153 | + ) |
| 154 | + self._logs[tid].final_circuit = circuit |
| 155 | + |
| 156 | + def show(self, level: LogLevel = LogLevel.INFO) -> None: |
| 157 | + """Show the stored logs >= level in the desired format. |
| 158 | +
|
| 159 | + Args: |
| 160 | + level: The logging level to filter the logs with. The method shows all logs with a |
| 161 | + `LogLevel` >= `level`. |
| 162 | + """ |
| 163 | + |
| 164 | + def print_log(log: _LoggerNode, pad=''): |
| 165 | + print(pad, f"Transformer-{1+log.transformer_id}: {log.transformer_name}", sep='') |
| 166 | + print(pad, "Initial Circuit:", sep='') |
| 167 | + print(textwrap.indent(str(log.initial_circuit), pad), "\n", sep='') |
| 168 | + for log_level, log_text in log.logs: |
| 169 | + if log_level.value >= level.value: |
| 170 | + print(pad, log_level, *log_text) |
| 171 | + print("\n", pad, "Final Circuit:", sep='') |
| 172 | + print(textwrap.indent(str(log.final_circuit), pad)) |
| 173 | + print("----------------------------------------") |
| 174 | + |
| 175 | + done = [0] * self._curr_id |
| 176 | + for i in range(self._curr_id): |
| 177 | + # Iterative DFS. |
| 178 | + stack = [(i, '')] if not done[i] else [] |
| 179 | + while len(stack) > 0: |
| 180 | + log_id, pad = stack.pop() |
| 181 | + print_log(self._logs[log_id], pad) |
| 182 | + done[log_id] = True |
| 183 | + for child_id in self._logs[log_id].nested_loggers[::-1]: |
| 184 | + stack.append((child_id, pad + ' ' * 4)) |
| 185 | + |
| 186 | + |
| 187 | +class NoOpTransformerLogger(TransformerLogger): |
| 188 | + """All calls to this logger are a no-op""" |
| 189 | + |
| 190 | + def register_initial(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None: |
| 191 | + pass |
| 192 | + |
| 193 | + def log(self, *args: str, level: LogLevel = LogLevel.INFO) -> None: |
| 194 | + pass |
| 195 | + |
| 196 | + def register_final(self, circuit: 'cirq.AbstractCircuit', transformer_name: str) -> None: |
| 197 | + pass |
| 198 | + |
| 199 | + def show(self, level: LogLevel = LogLevel.INFO) -> None: |
| 200 | + pass |
| 201 | + |
| 202 | + |
| 203 | +@dataclasses.dataclass() |
| 204 | +class TransformerContext: |
| 205 | + """Stores common configurable options for transformers. |
| 206 | +
|
| 207 | + Args: |
| 208 | + logger: `cirq.TransformerLogger` instance, which is a stateful logger used for logging |
| 209 | + the actions of individual transformer stages. The same logger instance should be |
| 210 | + shared across different transformer calls. |
| 211 | + ignore_tags: Tuple of tags which should be ignored while applying transformations on a |
| 212 | + circuit. Transformers should not transform any operation marked with a tag that |
| 213 | + belongs to this tuple. Note that any instance of a Hashable type (like `str`, |
| 214 | + `cirq.VirtualTag` etc.) is a valid tag. |
| 215 | + """ |
| 216 | + |
| 217 | + logger: TransformerLogger = NoOpTransformerLogger() |
| 218 | + ignore_tags: Tuple[Hashable, ...] = () |
| 219 | + |
| 220 | + |
| 221 | +TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit'] |
| 222 | +_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE] |
| 223 | + |
| 224 | + |
| 225 | +def _transform_and_log( |
| 226 | + func: _TRANSFORMER_TYPE[CIRCUIT_TYPE], |
| 227 | + transformer_name: str, |
| 228 | + circuit: 'cirq.AbstractCircuit', |
| 229 | + context: TransformerContext, |
| 230 | +) -> CIRCUIT_TYPE: |
| 231 | + """Helper to log initial and final circuits before and after calling the transformer.""" |
| 232 | + |
| 233 | + context.logger.register_initial(circuit, transformer_name) |
| 234 | + transformed_circuit = func(circuit, context) |
| 235 | + context.logger.register_final(transformed_circuit, transformer_name) |
| 236 | + return transformed_circuit |
| 237 | + |
| 238 | + |
| 239 | +def _transformer_class( |
| 240 | + cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], |
| 241 | +) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]: |
| 242 | + old_func = cls.__call__ |
| 243 | + |
| 244 | + def transformer_with_logging_cls( |
| 245 | + self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], |
| 246 | + circuit: 'cirq.AbstractCircuit', |
| 247 | + context: TransformerContext, |
| 248 | + ) -> CIRCUIT_TYPE: |
| 249 | + def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE: |
| 250 | + return old_func(self, c, ct) |
| 251 | + |
| 252 | + return _transform_and_log(call_old_func, cls.__name__, circuit, context) |
| 253 | + |
| 254 | + setattr(cls, '__call__', transformer_with_logging_cls) |
| 255 | + return cls |
| 256 | + |
| 257 | + |
| 258 | +def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]: |
| 259 | + @functools.wraps(func) |
| 260 | + def transformer_with_logging_func( |
| 261 | + circuit: 'cirq.AbstractCircuit', |
| 262 | + context: TransformerContext, |
| 263 | + ) -> CIRCUIT_TYPE: |
| 264 | + return _transform_and_log(func, func.__name__, circuit, context) |
| 265 | + |
| 266 | + return transformer_with_logging_func |
| 267 | + |
| 268 | + |
| 269 | +@overload |
| 270 | +def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]: |
| 271 | + pass |
| 272 | + |
| 273 | + |
| 274 | +@overload |
| 275 | +def transformer( |
| 276 | + cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]], |
| 277 | +) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]: |
| 278 | + pass |
| 279 | + |
| 280 | + |
| 281 | +def transformer(cls_or_func: Any) -> Any: |
| 282 | + """Decorator to verify API and append logging functionality to transformer functions & classes. |
| 283 | +
|
| 284 | + The decorated function or class must satisfy |
| 285 | + `Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example: |
| 286 | +
|
| 287 | + >>> @cirq.transformer |
| 288 | + >>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit: |
| 289 | + >>> ... |
| 290 | +
|
| 291 | + The decorated class must implement the `__call__` method to satisfy the above API. |
| 292 | +
|
| 293 | + >>> @cirq.transformer |
| 294 | + >>> class ConvertToSqrtISwaps: |
| 295 | + >>> def __init__(self): |
| 296 | + >>> ... |
| 297 | + >>> def __call__( |
| 298 | + >>> self, circuit: cirq.Circuit, context: cirq.TransformerContext |
| 299 | + >>> ) -> cirq.Circuit: |
| 300 | + >>> ... |
| 301 | +
|
| 302 | + Args: |
| 303 | + cls_or_func: The callable class or method to be decorated. |
| 304 | +
|
| 305 | + Returns: |
| 306 | + Decorated class / method which includes additional logging boilerplate. The decorated |
| 307 | + callable always receives a copy of the input circuit so that the input is never mutated. |
| 308 | + """ |
| 309 | + if isinstance(cls_or_func, type): |
| 310 | + return _transformer_class(cls_or_func) |
| 311 | + else: |
| 312 | + assert callable(cls_or_func) |
| 313 | + return _transformer_func(cls_or_func) |
0 commit comments