diff --git a/example_script.py b/example_script.py
index 7376018..50171a9 100644
--- a/example_script.py
+++ b/example_script.py
@@ -16,7 +16,7 @@
 n_epochs = 500
 
 # Tolerance after which to interrupt a given compilation
-tol_interrupt = 1e-10
+tol = 1e-10
 
 # Number of datapoints to collect during optimization. Needs to be a divisor of n_epochs.
 # Due to a bug, this currently needs to be less than n_epochs.
@@ -26,43 +26,20 @@
 # This number is re-processed below to take num_records into account.
 max_const = 50
 
-# Get dimension of the group to use for setting memory of optimizer
-d, *_ = unicirc.ansatz_specs(n, group)
-memory_size = 5 * d # Generous memory budget
-learning_rate = None # L-BFGS uses line search, learning rate is not required
-
-# 3. Create cost function, optimizer, and optimization executable
-matrix_fn = unicirc.matrix_v2(n, group)
-cost_fn = unicirc.make_cost_fn(matrix_fn)
-optimizer = optax.lbfgs(learning_rate=None, memory_size=5 * d)
-run_optimization = unicirc.make_optimization_run(cost_fn, optimizer)
-
-
-# Reprocess max_const to take num_records into account
-_epochs_per_record = n_epochs // num_records
-max_const = max((2, max_const // _epochs_per_record))
-
-opt_run = jax.jit(
-    partial(
-        run_optimization,
-        n_epochs=n_epochs,
-        tol=tol_interrupt,
-        max_const=max_const,
-        progress_bar=True,
-        num_records=num_records,
-    )
+costs, thetas, successful = unicirc.compile(
+    target,
+    group,
+    key=seed,
+    tol=tol,
+    n_epochs=n_epochs,
+    max_const=max_const,
+    break_at_success=True,
 )
-target_dag = target.conj().T
-# Setting tol here implies that the compiler stops after an attempt has converged to tol_interrupt
-# For tol=None, the compiler just attempts compilation `max_attempts` times.
-energies, thetas, successful = unicirc.compile(d, partial(opt_run, target_dag=target_dag), key=seed, tol=tol_interrupt, max_attempts=10)
-
-print(f"Compiler was successful: {successful}")
-for idx, energy_attempt in enumerate(energies):
-    plt.plot(*energy_attempt.T, label=f"Attempt {idx}")
+for idx, cost_per_attempt in enumerate(costs):
+    plt.plot(*cost_per_attempt.T, label=f"Attempt {idx}")
 
-plt.plot([0, n_epochs], [tol_interrupt]*2, ls=":", color="k")
+plt.plot([0, n_epochs], [tol]*2, ls=":", color="k")
 plt.legend()
 plt.yscale("log")
 plt.show()
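Note on the change above: the manual setup (ansatz specs, cost function, optimizer, jitted run) is replaced by the single unicirc.compile call added in unicirc/optimization.py further down. A minimal self-contained sketch of the same pattern follows; the scipy-generated target and the "SU" group label are illustrative assumptions, not part of this patch, while the keyword arguments are exactly those of the new function.

# Sketch only (not part of the patch): the new one-call compilation API.
from scipy.stats import unitary_group
import unicirc

n = 2                                             # number of qubits (illustrative)
target = unitary_group.rvs(2**n, random_state=0)  # random compilation target (assumed valid input)

costs, thetas, successful = unicirc.compile(
    target,
    "SU",            # group label is an assumption; pass whichever labels unicirc supports
    key=42,          # integer keys are promoted to a jax.random.PRNGKey internally
    tol=1e-10,
    n_epochs=500,
    break_at_success=True,
)
print(f"Attempts run: {len(costs)}, successful: {successful}")

Because compile infers the qubit count from the target's shape, the script no longer needs any separate ansatz bookkeeping.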
diff --git a/unicirc/__init__.py b/unicirc/__init__.py
index 1870dc5..aa5950e 100644
--- a/unicirc/__init__.py
+++ b/unicirc/__init__.py
@@ -13,6 +13,7 @@
     make_optimization_run,
     compile_adapt,
     sample_from_group,
+    repeated_optimization,
 )
 from .count_clifford import count_clifford
 from .universality_test import (
diff --git a/unicirc/optimization.py b/unicirc/optimization.py
index 5312197..283e70c 100644
--- a/unicirc/optimization.py
+++ b/unicirc/optimization.py
@@ -8,7 +8,7 @@
 from jax import numpy as jnp
 import optax
 from jax_tqdm import loop_tqdm
-from .matrix import matrix_v2_partial
+from .matrix import matrix_v2_partial, matrix_v2
 from .universal_ansatze import ansatz_specs
 from scipy.stats import unitary_group, ortho_group, special_ortho_group
@@ -39,6 +39,17 @@ def cost_fn(params, target_dag):
     return cost_fn
 
 
+def _make_cost_fn(matrix_fn, target):
+    """Produce a single-argument cost function for compiling ``target`` with ``matrix_fn``."""
+
+    target_dag = target.conj().T
+
+    def cost_fn(params):
+        U = matrix_fn(params)
+        return 1 - jnp.abs(jnp.trace(target_dag @ U)) / len(target_dag)
+
+    return cost_fn
+
 def make_optimization_run(cost_fn, optimizer):
     """Create a full optimization workflow executable, with tailored syntax for
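The _make_cost_fn helper above bakes the compilation target into a closure, so the optimizer only sees a single-argument cost c(theta) = 1 - |Tr(target^dagger @ U(theta))| / N, which vanishes exactly when U(theta) equals the target up to a global phase. A quick sanity-check sketch (not part of the patch; it imports the private helper purely for illustration):

# Sanity check (illustrative): the cost is zero when the "ansatz" reproduces the
# target up to a global phase, and positive for a generic mismatch.
import jax.numpy as jnp
from scipy.stats import unitary_group
from unicirc.optimization import _make_cost_fn  # private helper, imported only for this check

target = jnp.asarray(unitary_group.rvs(4, random_state=0))

# Toy "matrix function": a global-phase rotation of the target itself.
phase_only = _make_cost_fn(lambda params: jnp.exp(1j * params[0]) * target, target)
print(phase_only(jnp.array([0.3])))  # ~0.0, the global phase is irrelevant

mismatch = _make_cost_fn(lambda params: jnp.eye(4, dtype=complex), target)
print(mismatch(jnp.array([0.0])))    # > 0 for a generic random target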
@@ -55,18 +66,17 @@ def make_optimization_run(cost_fn, optimizer):
     compiled_cost_fn = jax.jit(cost_fn)
 
     @jax.jit
-    def partial_step(opt_state, theta, last_val, target_dag):
+    def partial_step(opt_state, theta, last_val):
         """Closure variables:
         value_and_grad
         optimizer
         compiled_cost_fn
         """
-        val, grad_circuit = value_and_grad(theta, target_dag)
+        val, grad_circuit = value_and_grad(theta)
         updates, opt_state = optimizer.update(
             grad_circuit,
             opt_state,
             theta,
-            target_dag=target_dag,
             value=val,
             grad=grad_circuit,
             value_fn=compiled_cost_fn,
@@ -75,7 +85,7 @@ def partial_step(opt_state, theta, last_val, target_dag):
         return opt_state, theta, val
 
     @jax.jit
-    def static_step(opt_state, theta, last_val, target_dag):
+    def static_step(opt_state, theta, last_val):
         return opt_state, theta, last_val
 
     def _or(a, b):
@@ -86,7 +96,7 @@ def _and(a, b):
         """Compute logical and in a JIT compatible manner."""
         return jax.lax.cond(a, lambda b: b, lambda b: False, b)
 
-    def optimization_step(i, values, tol, max_const, target_dag, record_mod=1):
+    def optimization_step(i, values, tol, max_const, record_mod=1):
         """Perform a step of an optimization process.
 
         Args:
@@ -98,7 +108,6 @@ def optimization_step(i, values, tol, max_const, target_dag, record_mod=1):
                 optimization is effectively interrupted.
             max_const (int): Maximal number of iterations for which a constant cost value is
                 allowed. Afterwards, the optimization is effectively interrupted.
-            target_dag (jnp.ndarray): Adjoint of the compilation target.
             record_mod (int): Interval at which to record the cost function values.
 
         Returns:
@@ -109,9 +118,9 @@ def optimization_step(i, values, tol, max_const, target_dag, record_mod=1):
         opt_state, theta, cost, last_val, rec_idx = values
 
         if (last_val < tol) or (i>max_const>0 and allclose(cost[i-max_const:i-1], last_val)):
-            opt_state, theta, val = static_step(opt_state, theta, last_val, target_dag)
+            opt_state, theta, val = static_step(opt_state, theta, last_val)
         else:
-            opt_state, theta, val = partial_step(opt_state, theta, last_val, target_dag)
+            opt_state, theta, val = partial_step(opt_state, theta, last_val)
 
         if i % record_mod == 0:
             cost[rec_idx] = val
@@ -137,7 +146,6 @@ def optimization_step(i, values, tol, max_const, target_dag, record_mod=1):
             opt_state,
             theta,
             last_val,
-            target_dag,
         )
         cost, rec_idx = jax.lax.cond(
             i % record_mod == 0,
@@ -151,13 +159,12 @@ def optimization_step(i, values, tol, max_const, target_dag, record_mod=1):
         return opt_state, theta, cost, val, rec_idx
 
     def run_optimization(
-        init_params, target_dag, n_epochs, tol, max_const, progress_bar, num_records=None
+        init_params, n_epochs, tol, max_const, progress_bar, num_records=None
     ):
         """Run optimization workflow based on ``cost_fn`` and ``optimizer``.
 
         Args:
             init_params (jnp.ndarray): Initial parameters to be optimized.
-            target_dag (jnp.ndarray): Adjoint of target unitary to be compiled.
             n_epochs (int): Number of epochs to optimize for.
             tol (float): Tolerance below which the optimization is effectively interrupted.
             max_const (int): Number of stagnating cost **recordings** (not optimization steps)
@@ -180,7 +187,6 @@ def run_optimization(
             optimization_step,
             tol=tol,
             max_const=max_const,
-            target_dag=target_dag,
             record_mod=record_mod,
         )
         if progress_bar:
@@ -192,7 +198,7 @@
     return run_optimization
 
 
-def compile(dim, optimization_run, key=None, tol=1e-10, max_attempts=10):
+def repeated_optimization(dim, optimization_run, key=None, tol=1e-10, max_attempts=10):
     """Repeatedly execute an optimization run until a convergence tolerance is hit.
 
     Args:
@@ -235,6 +241,62 @@ def compile(dim, optimization_run, key=None, tol=1e-10, max_attempts=10):
     return costs, thetas, successful
 
 
+def compile(target, group, key=None, tol=1e-10, n_epochs=1000, max_attempts=10, learning_rate=None, memory_size=None, max_const=100, num_records=None, progress_bar=True, break_at_success=True, print_report=True):
+    """Compile ``target`` into the ansatz for ``group`` by running up to
+    ``max_attempts`` L-BFGS optimization attempts from random initial parameters.
+
+    Returns the recorded costs and optimized parameters of each attempt, plus a
+    success flag per attempt when ``tol`` is not None.
+    """
+    N = np.shape(target)[1]
+    n = int(np.round(np.log2(N)))
+    d, *_ = ansatz_specs(n, group)
+    matrix_fn = matrix_v2(n, group)
+    cost_fn = _make_cost_fn(matrix_fn, target)
+    memory_size = memory_size or 5 * d
+    optimizer = optax.lbfgs(learning_rate=learning_rate, memory_size=memory_size)
+    optimization_run = make_optimization_run(cost_fn, optimizer)
+
+    if num_records is None:
+        num_records = n_epochs // 2
+    _epochs_per_record = n_epochs // num_records
+    max_const = max((2, max_const // _epochs_per_record))
+    opt_run = jax.jit(partial(optimization_run, n_epochs=n_epochs, tol=tol, max_const=max_const, progress_bar=progress_bar, num_records=num_records))
+
+    if key is None:
+        key = jax.random.PRNGKey(np.random.randint(24125))
+    elif isinstance(key, int):
+        key = jax.random.PRNGKey(key)
+
+    costs = []
+    thetas = []
+    successful = []
+    for _ in range(max_attempts):
+        key, use_key = jax.random.split(key)
+        # 0.2 is a good scaling factor for SU on n<=5 qubits
+        theta = jax.random.normal(use_key, (d,)) * 0.2
+        theta, cost = opt_run(theta)
+        costs.append(cost)
+        thetas.append(theta)
+        if tol is not None:
+            if np.min(cost[:, 1]) <= tol:
+                successful.append(True)
+                if break_at_success:
+                    break
+            else:
+                successful.append(False)
+
+    if print_report:
+        group_label = "Sp*" if group=="Sp" else group
+        main_stat = f"\nThe compiler for {group_label}({N}) ran {len(costs)}/{max_attempts} attempts"
+        if tol is not None:
+            if break_at_success:
+                main_stat += " before succeeding." if any(successful) else " without succeeding."
+            else:
+                num_success = sum(successful)
+                main_stat += f", out of which {num_success} were successful."
+        print(main_stat)
+        print(f"Used hyperparameters:\n {tol=}\n {n_epochs=}\n {learning_rate=}\n {memory_size=}\n {max_const=}")
+
+    return costs, thetas, successful
+
+
+
 def compile_adapt(
     target,
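Besides the early-stopping behaviour used in example_script.py, the new compile entry point can also collect statistics: with break_at_success=False every attempt is executed and successful holds one flag per attempt, while tol=None runs all attempts unconditionally (and leaves successful empty). A usage sketch with purely illustrative values; the "SU" group label is again an assumption:

# Sketch (not part of the patch): estimate how often the ansatz converges for a target.
from scipy.stats import unitary_group
import unicirc

target = unitary_group.rvs(4, random_state=1)  # illustrative 2-qubit target
costs, thetas, successful = unicirc.compile(
    target,
    "SU",                    # assumed group label
    key=0,
    tol=1e-8,                # illustrative tolerance
    n_epochs=1000,
    max_attempts=25,
    break_at_success=False,  # run every attempt and record a success flag for each
)
print(f"{sum(successful)}/{len(successful)} attempts converged below the tolerance")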