From 2defbc9d77413fc8141eb9745e2f462c454e91a6 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Tue, 22 Apr 2025 16:03:44 +0300 Subject: [PATCH] refactor: compact down function calls --- src/grovers_visualizer/main.py | 64 +++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index 181e3fb..13b5c38 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -9,9 +9,11 @@ using matplotlib. from collections.abc import Iterator from itertools import product from math import floor, pi, sqrt +from typing import Callable import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt from matplotlib.axes import Axes from matplotlib.container import BarContainer from qiskit import QuantumCircuit @@ -77,6 +79,20 @@ def all_states(n_qubits: int) -> Iterator[QubitState]: yield QubitState("".join(bits)) +def optimal_grover_iterations(n_qubits: int) -> int: + """Return the optimal number of Grover iterations for n qubits.""" + return floor(pi / 4 * sqrt(2**n_qubits)) + + +def get_bar_color(state: str, target_state: QubitState | None, iteration: int, optimal_iteration: int | None) -> str: + """Return the color for a bar based on state and iteration.""" + if state != target_state: + return "skyblue" + if optimal_iteration and iteration % optimal_iteration == 0 and iteration != 0: + return "green" + return "orange" + + def plot_amplitudes_live( ax: Axes, bars: BarContainer, @@ -87,25 +103,20 @@ def plot_amplitudes_live( target_state: QubitState | None = None, optimal_iteration: int | None = None, ) -> None: - amplitudes = statevector.data.real # Real part of amplitudes + amplitudes: npt.NDArray[np.float64] = statevector.data.real # Real part of amplitudes mean = np.mean(amplitudes) + for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False): bar.set_height(amp) - if state == target_state: - if optimal_iteration is not None and iteration == optimal_iteration: - bar.set_color("green") - else: - bar.set_color("orange") - else: - bar.set_color("skyblue") + bar.set_color(get_bar_color(state, target_state, iteration, optimal_iteration)) + ax.set_title(f"Iteration {iteration}: {step_label}") ax.set_ylim(-1, 1) - # Remove previous mean line(s) - for l in ax.lines: + for l in ax.lines: # Remove previous mean line(s) l.remove() - ax.axhline(mean, color="red", linestyle="--", label="Mean") + ax.axhline(float(mean), color="red", linestyle="--", label="Mean") if not ax.get_legend(): ax.legend(loc="upper right") @@ -116,7 +127,7 @@ def main() -> None: target_state = QubitState("1010") n_qubits = len(target_state) basis_states = [str(bit) for bit in all_states(n_qubits)] - optimal_iterations = floor(pi / 4 * sqrt(2**n_qubits)) + optimal_iterations = optimal_grover_iterations(n_qubits) plt.ion() fig, ax = plt.subplots(figsize=(8, 3)) @@ -126,28 +137,25 @@ def main() -> None: ax.set_ylim(-1, 1) ax.set_title("Grover Amplitudes") + def step_and_plot( + operation: Callable[[QuantumCircuit], None] | None, + step_label: str, + iteration: int, + ) -> None: + if operation is not None: + operation(qc) + sv = Statevector.from_instruction(qc) + plot_amplitudes_live(ax, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations) + # Start with Hadamard qc = QuantumCircuit(n_qubits) qc.h(range(n_qubits)) - sv = Statevector.from_instruction(qc) - plot_amplitudes_live(ax, bars, sv, basis_states, "Hadamard (Initialization)", 0, target_state, optimal_iterations) + step_and_plot(None, "Hadamard (Initialization)", 0) iteration = 1 while plt.fignum_exists(fig.number): - # Oracle phase - oracle(qc, target_state) - sv = Statevector.from_instruction(qc) - plot_amplitudes_live( - ax, bars, sv, basis_states, "Oracle (Query Phase)", iteration, target_state, optimal_iterations - ) - - # Diffusion phase - diffusion(qc, n_qubits) - sv = Statevector.from_instruction(qc) - plot_amplitudes_live( - ax, bars, sv, basis_states, "Diffusion (Inversion Phase)", iteration, target_state, optimal_iterations - ) - + step_and_plot(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration) + step_and_plot(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration) iteration += 1 plt.ioff()