refactor(cli): refactor CLI into modular simulation and visualization components

- Extract Grover iteration logic into simulation.py (grover_evolver generator)
- Move all Matplotlib setup/update/teardown into visualization.py (GroverVisualizer)
This commit is contained in:
Kristofers Solo 2025-05-08 14:11:18 +03:00
parent 28a9bc6ab0
commit f4b99262ec
Signed by: kristoferssolo
GPG Key ID: 74FF8144483D82C8
13 changed files with 269 additions and 221 deletions

View File

@ -1,6 +1,6 @@
[project]
name = "grovers-visualizer"
version = "0.4.2"
version = "0.4.3"
description = "A tiny Python package that steps through Grovers Search algorithm."
readme = "README.md"
requires-python = ">=3.10"

View File

@ -1,28 +1,8 @@
from math import floor, pi, sqrt
from qiskit import QuantumCircuit
from .state import QubitState
def grover_search(target_state: QubitState, iterations: int | None = None) -> QuantumCircuit:
"""Construct a Grover search circuit for the given target state."""
n = len(target_state)
qc = QuantumCircuit(n, n)
qc.h(range(n))
if iterations is None or iterations < 0:
iterations = floor(pi / 4 * sqrt(2**n))
for _ in range(iterations):
oracle(qc, target_state)
diffusion(qc, n)
qc.measure(range(n), range(n))
return qc
def oracle(qc: QuantumCircuit, target_state: QubitState) -> None:
"""Oracle that flips the sign of the target state."""
n = len(target_state)

View File

@ -1,75 +1,15 @@
from math import asin, sqrt
from typing import Callable
import matplotlib.pyplot as plt
from matplotlib.backend_bases import KeyEvent
from matplotlib.gridspec import GridSpec
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
from grovers_visualizer.simulation import grover_evolver
from grovers_visualizer.visualization import GroverVisualizer
from .args import Args
from .circuit import diffusion, oracle
from .plot import SinePlotData, plot_amplitudes, plot_circle, plot_sine
from .utils import all_states, optimal_grover_iterations
def run_cli(args: Args) -> None:
target_state = args.target
n_qubits = len(target_state)
basis_states = [str(bit) for bit in all_states(n_qubits)]
optimal_iterations = optimal_grover_iterations(n_qubits)
theta = 2 * asin(1 / sqrt(2**n_qubits))
state_angle = 0.5 * theta
vis = GroverVisualizer(args.target, pause=args.speed)
plt.ion()
fig = plt.figure(figsize=(14, 6))
gs = GridSpec(2, 2, width_ratios=(3, 1), figure=fig)
ax_bar = fig.add_subplot(gs[0, 0])
ax_sine = fig.add_subplot(gs[1, 0])
ax_circle = fig.add_subplot(gs[:, 1])
bars = ax_bar.bar(basis_states, [0] * len(basis_states), color="skyblue")
ax_bar.set_ylim(-1, 1)
ax_bar.set_title("Amplitudes (example)")
sine_data = SinePlotData()
def plot_bar(
operation: Callable[[QuantumCircuit], None] | None,
step_label: str,
iteration: int,
) -> None:
if operation is not None:
operation(qc)
sv = Statevector.from_instruction(qc)
plot_amplitudes(ax_bar, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations)
# Start with Hadamard
qc = QuantumCircuit(n_qubits)
qc.h(range(n_qubits))
plot_bar(None, "Hadamard (Initialization)", 0)
iteration = 1
running = True
def on_key(event: KeyEvent) -> None:
nonlocal running
if event.key == "q":
running = False
cid = fig.canvas.mpl_connect("key_press_event", on_key)
while plt.fignum_exists(fig.number) and running:
plot_bar(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration)
plot_bar(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration)
plot_circle(ax_circle, iteration, optimal_iterations, theta, state_angle)
sine_data.calc_and_append_probability(iteration, theta)
plot_sine(ax_sine, sine_data)
plt.pause(args.speed)
iteration += 1
if args.iterations > 0 and iteration > args.iterations:
for it, sv in grover_evolver(vis.target, args.iterations):
if not vis.is_running:
break
vis.update(it, sv)
fig.canvas.mpl_disconnect(cid)
plt.ioff()
vis.finalize()

View File

@ -1,129 +0,0 @@
from dataclasses import dataclass, field
from math import cos, sin
import numpy as np
import numpy.typing as npt
from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from matplotlib.patches import Circle
from qiskit.quantum_info import Statevector
from .state import QubitState
from .utils import get_bar_color, is_optimal_iteration
def plot_amplitudes(
ax: Axes,
bars: BarContainer,
statevector: Statevector,
basis_states: list[str],
iteration_label: str,
iteration: int,
target_state: QubitState | None = None,
optimal_iteration: int | None = None,
) -> None:
amplitudes: npt.NDArray[np.float64] = statevector.data.real # Real part of amplitudes
mean = np.mean(amplitudes)
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
bar.set_height(amp)
bar.set_color(get_bar_color(state, target_state, iteration, optimal_iteration))
ax.set_title(f"Iteration {iteration}: {iteration_label}")
ax.set_ylim(-1, 1)
for l in ax.lines: # Remove previous mean line(s)
l.remove()
# Draw axes and mean
ax.axhline(0, color="black", linewidth=0.5)
ax.axhline(float(mean), color="red", linestyle="--", label="Mean")
if not ax.get_legend():
ax.legend(loc="upper right")
def plot_circle(
ax: Axes,
iteration: int,
optimal_iterations: int,
theta: float,
state_angle: float,
) -> None:
ax.clear()
ax.set_aspect("equal")
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_xlabel("Unmarked amplitude")
ax.set_ylabel("Target amplitude")
ax.set_title("Grover State Vector Rotation")
# Draw unit circle
circle = Circle((0, 0), 1, color="gray", fill=False)
ax.add_artist(circle)
# Draw axes
ax.axhline(0, color="black", linewidth=0.5)
ax.axvline(0, color="black", linewidth=0.5)
# Draw labels
ax.text(1.05, 0, "", va="center", ha="left", fontsize=10)
ax.text(0, 1.05, "1", va="bottom", ha="center", fontsize=10)
ax.text(-1.05, 0, "", va="center", ha="right", fontsize=10)
ax.text(0, -1.05, "-1", va="top", ha="center", fontsize=10)
angle = state_angle + iteration * theta
x, y = cos(angle), sin(angle)
is_optimal = is_optimal_iteration(iteration, optimal_iterations)
# Arrow color: green at optimal, blue otherwise
color = "green" if is_optimal else "blue"
ax.arrow(0, 0, x, y, head_width=0.07, head_length=0.1, fc=color, ec=color, length_includes_head=True)
# Probability of target state is y^2
prob = y**2
# Draw the value at the tip of the arrow
ax.text(
x,
y,
f"{prob:.2f}",
color=color,
fontsize=10,
ha="left" if x >= 0 else "right",
va="bottom" if y >= 0 else "top",
fontweight="bold",
bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.7, "boxstyle": "round,pad=0.2"},
)
ax.set_title(
f"Grover State Vector Rotation\nIteration {iteration} | Probability of target: {prob}{' (optimal)' if is_optimal else ''}"
)
@dataclass
class SinePlotData:
x: list[float] = field(default_factory=list)
y: list[float] = field(default_factory=list)
def append(self, x: float, y: float) -> None:
self.x.append(x)
self.y.append(y)
def calc_and_append_probability(self, iteration: int, theta: float) -> None:
prob = sin((2 * iteration + 1) * theta / 2) ** 2
self.append(iteration, prob)
def plot_sine(
ax: Axes,
sine_data: SinePlotData,
) -> None:
ax.clear()
ax.plot(sine_data.x, sine_data.y, marker="o", color="purple", label="Target Probability")
ax.set_xlabel("Iteration")
ax.set_ylabel("Probability")
ax.set_title("Grover Target Probability vs. Iteration")
ax.set_ylim(0, 1)
ax.set_xlim(0, max(10, max(sine_data.x) + 1))
ax.legend()

View File

@ -0,0 +1,5 @@
from .amplitudes import plot_amplitudes
from .circle import plot_circle
from .sine import SinePlotData, plot_sine
__all__ = ("SinePlotData", "plot_amplitudes", "plot_circle", "plot_sine")

View File

@ -0,0 +1,39 @@
import numpy as np
import numpy.typing as npt
from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from qiskit.quantum_info import Statevector
from grovers_visualizer.state import QubitState
from grovers_visualizer.utils import get_bar_color
def plot_amplitudes(
ax: Axes,
bars: BarContainer,
statevector: Statevector,
basis_states: list[str],
iteration_label: str,
iteration: int,
target_state: QubitState | None = None,
optimal_iteration: int | None = None,
) -> None:
amplitudes: npt.NDArray[np.float64] = statevector.data.real # Real part of amplitudes
mean = np.mean(amplitudes)
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
bar.set_height(amp)
bar.set_color(get_bar_color(state, target_state, iteration, optimal_iteration))
ax.set_title(f"Iteration {iteration}: {iteration_label}")
ax.set_ylim(-1, 1)
for l in ax.lines: # Remove previous mean line(s)
l.remove()
# Draw axes and mean
ax.axhline(0, color="black", linewidth=0.5)
ax.axhline(float(mean), color="red", linestyle="--", label="Mean")
if not ax.get_legend():
ax.legend(loc="upper right")

View File

@ -0,0 +1,64 @@
from math import cos, sin
from matplotlib.axes import Axes
from matplotlib.patches import Circle
from grovers_visualizer.utils import is_optimal_iteration
def plot_circle(
ax: Axes,
iteration: int,
optimal_iterations: int,
theta: float,
state_angle: float,
) -> None:
ax.clear()
ax.set_aspect("equal")
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_xlabel("Unmarked amplitude")
ax.set_ylabel("Target amplitude")
ax.set_title("Grover State Vector Rotation")
# Draw unit circle
circle = Circle((0, 0), 1, color="gray", fill=False)
ax.add_artist(circle)
# Draw axes
ax.axhline(0, color="black", linewidth=0.5)
ax.axvline(0, color="black", linewidth=0.5)
# Draw labels
ax.text(1.05, 0, "", va="center", ha="left", fontsize=10)
ax.text(0, 1.05, "1", va="bottom", ha="center", fontsize=10)
ax.text(-1.05, 0, "", va="center", ha="right", fontsize=10)
ax.text(0, -1.05, "-1", va="top", ha="center", fontsize=10)
angle = state_angle + iteration * theta
x, y = cos(angle), sin(angle)
is_optimal = is_optimal_iteration(iteration, optimal_iterations)
# Arrow color: green at optimal, blue otherwise
color = "green" if is_optimal else "blue"
ax.arrow(0, 0, x, y, head_width=0.07, head_length=0.1, fc=color, ec=color, length_includes_head=True)
# Probability of target state is y^2
prob = y**2
# Draw the value at the tip of the arrow
ax.text(
x,
y,
f"{prob:.2f}",
color=color,
fontsize=10,
ha="left" if x >= 0 else "right",
va="bottom" if y >= 0 else "top",
fontweight="bold",
bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.7, "boxstyle": "round,pad=0.2"},
)
ax.set_title(
f"Grover State Vector Rotation\nIteration {iteration} | Probability of target: {prob}{' (optimal)' if is_optimal else ''}"
)

View File

@ -0,0 +1,32 @@
from dataclasses import dataclass, field
from math import sin
from matplotlib.axes import Axes
@dataclass
class SinePlotData:
x: list[float] = field(default_factory=list)
y: list[float] = field(default_factory=list)
def append(self, x: float, y: float) -> None:
self.x.append(x)
self.y.append(y)
def calc_and_append_probability(self, iteration: int, theta: float) -> None:
prob = sin((2 * iteration + 1) * theta / 2) ** 2
self.append(iteration, prob)
def plot_sine(
ax: Axes,
sine_data: SinePlotData,
) -> None:
ax.clear()
ax.plot(sine_data.x, sine_data.y, marker="o", color="purple", label="Target Probability")
ax.set_xlabel("Iteration")
ax.set_ylabel("Probability")
ax.set_title("Grover Target Probability vs. Iteration")
ax.set_ylim(0, 1)
ax.set_xlim(0, max(10, max(sine_data.x) + 1))
ax.legend()

View File

@ -0,0 +1,31 @@
from collections.abc import Iterator
from itertools import count
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
from grovers_visualizer.circuit import diffusion, oracle
from grovers_visualizer.state import QubitState
def grover_evolver(target: QubitState, max_iterations: int = 0) -> Iterator[tuple[int, Statevector]]:
"""Yields (iteration, statevector) pairs.
iteration=0 is the uniform-Hadamard initialization. If
max_iterations > 0, stop after that many iterations. If
max_iterations == 0, run indefinitely (until the consumer breaks).
"""
n_qubits = len(target)
qc = QuantumCircuit(n_qubits)
qc.h(range(n_qubits))
# initial statevector
yield 0, Statevector.from_instruction(qc)
# pick an iterator for subsequent steps
iter = range(1, max_iterations + 1) if max_iterations > 0 else count(1)
for i in iter:
oracle(qc, target)
diffusion(qc, n_qubits)
yield i, Statevector.from_instruction(qc)

View File

@ -10,11 +10,11 @@ def is_dearpygui_available() -> bool:
return False
def run_dpg_ui(args: Args) -> None:
def run_dpg_ui(_args: Args) -> None:
if not is_dearpygui_available():
print("DearPyGui is not installed. Install with: pip install 'grovers-visualizer[ui]'")
return
from .dpg import run_dearpygui_ui
run_dearpygui_ui(args)
run_dearpygui_ui()

View File

@ -1,7 +1,9 @@
import dearpygui.dearpygui as dpg
from grovers_visualizer.args import Args
def run_dearpygui_ui() -> None:
def run_dearpygui_ui(_args: Args) -> None:
dpg.create_context()
dpg.create_viewport(title="Grover's Search Visualizer", width=900, height=600)
dpg.setup_dearpygui()

View File

@ -14,7 +14,7 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
def optimal_grover_iterations(n_qubits: int) -> int:
"""Return the optimal number of Grover iterations for n qubits."""
return floor(pi / 4 * sqrt(2**n_qubits))
return floor(pi / 4 * sqrt(2.0**n_qubits))
def is_optimal_iteration(iteration: int, optimal_iteration: int) -> bool:

View File

@ -0,0 +1,84 @@
from math import asin, sqrt
from typing import TYPE_CHECKING
import matplotlib.pyplot as plt
from matplotlib.backend_bases import Event, KeyEvent
from matplotlib.gridspec import GridSpec
from qiskit.quantum_info import Statevector
from grovers_visualizer.plot import SinePlotData, plot_amplitudes, plot_circle, plot_sine
from grovers_visualizer.state import QubitState
from grovers_visualizer.utils import all_states, optimal_grover_iterations
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from matplotlib.figure import Figure
class GroverVisualizer:
def __init__(self, target: QubitState, pause: float = 0.5) -> None:
self.target: QubitState = target
self.n: int = len(self.target)
self.basis_states: list[str] = [str(b) for b in all_states(self.n)]
self.optimal: int = optimal_grover_iterations(self.n)
self.theta: float = 2 * asin(1 / sqrt(2.0**self.n))
self.state_angle: float = 0.5 * self.theta
self.sine_data: SinePlotData = SinePlotData()
self.is_running: bool = True
self.pause: float = pause
self._build_figure()
def _build_figure(self) -> None:
plt.ion()
self.fig: Figure = plt.figure(figsize=(14, 6))
gs = GridSpec(2, 2, width_ratios=(3, 1), figure=self.fig)
self.ax_bar: Axes = self.fig.add_subplot(gs[0, 0])
self.ax_sine: Axes = self.fig.add_subplot(gs[1, 0])
self.ax_circle: Axes = self.fig.add_subplot(gs[:, 1])
# bars
self.bars: BarContainer = self.ax_bar.bar(self.basis_states, [0] * len(self.basis_states), color="skyblue")
self.ax_bar.set_ylim(-1, 1)
self.ax_bar.set_title("Amplitudes (example)")
# key handler to quit
self.cid: int = self.fig.canvas.mpl_connect("key_press_event", self._on_key)
def _on_key(self, event: Event) -> None:
if isinstance(event, KeyEvent) and event.key == "q":
self.is_running = False
def update(self, iteration: int, sv: Statevector) -> None:
"""Given (iteration, Statevector), update all three plots."""
# amplitudes
plot_amplitudes(
self.ax_bar,
self.bars,
sv,
self.basis_states,
"Grover Iteration",
iteration,
self.target,
self.optimal,
)
# circle
plot_circle(
self.ax_circle,
iteration,
self.optimal,
self.theta,
self.state_angle,
)
# sine curve
self.sine_data.calc_and_append_probability(iteration, self.theta)
plot_sine(self.ax_sine, self.sine_data)
plt.pause(self.pause)
def finalize(self) -> None:
"""Clean up after loop ends."""
self.fig.canvas.mpl_disconnect(self.cid)
plt.ioff()