From c928ad55a97759a746fe4d3bc0731d5a0b81300f Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Tue, 22 Apr 2025 15:38:39 +0300 Subject: [PATCH] feat: add step visualization --- src/grovers_visualizer/main.py | 102 ++++++++++++++++++++++---------- src/grovers_visualizer/state.py | 11 +++- 2 files changed, 81 insertions(+), 32 deletions(-) diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index acb4d32..181e3fb 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -11,9 +11,11 @@ from itertools import product from math import floor, pi, sqrt import matplotlib.pyplot as plt +import numpy as np from matplotlib.axes import Axes +from matplotlib.container import BarContainer from qiskit import QuantumCircuit -from qiskit_aer import AerSimulator +from qiskit.quantum_info import Statevector from grovers_visualizer.gates import apply_phase_inversion, encode_target_state from grovers_visualizer.state import QubitState @@ -75,40 +77,78 @@ def all_states(n_qubits: int) -> Iterator[QubitState]: yield QubitState("".join(bits)) +def plot_amplitudes_live( + ax: Axes, + bars: BarContainer, + statevector: Statevector, + basis_states: list[str], + step_label: str, + iteration: int, + target_state: QubitState | None = None, + optimal_iteration: int | None = None, +) -> None: + amplitudes = statevector.data.real # Real part of amplitudes + mean = np.mean(amplitudes) + for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False): + bar.set_height(amp) + if state == target_state: + if optimal_iteration is not None and iteration == optimal_iteration: + bar.set_color("green") + else: + bar.set_color("orange") + else: + bar.set_color("skyblue") + ax.set_title(f"Iteration {iteration}: {step_label}") + ax.set_ylim(-1, 1) + # Remove previous mean line(s) + + for l in ax.lines: + l.remove() + + ax.axhline(mean, color="red", linestyle="--", label="Mean") + + if not ax.get_legend(): + ax.legend(loc="upper right") + plt.pause(1) + + def main() -> None: - shots = 128 - target = QubitState("1010") - n_qubits = len(target) - - qc = grover_search(target, iterations=1) - simulator = AerSimulator() - job = simulator.run(qc, shots=shots, memory=True) - result = job.result() - memory = result.get_memory(qc) # List of measurement results, one per shot - - print(qc) # draw scheme - - print(f"Target: {target}") - - # Ensure all possible states are present in the bar chart - all_states = ["".join(bits) for bits in product("01", repeat=n_qubits)] - counts = dict.fromkeys(all_states, 0) + target_state = QubitState("1010") + n_qubits = len(target_state) + basis_states = [str(bit) for bit in all_states(n_qubits)] + optimal_iterations = floor(pi / 4 * sqrt(2**n_qubits)) plt.ion() - _, ax = plt.subplots(figsize=(6, 2)) - bars = ax.bar(all_states, [0] * len(all_states), color="skyblue") - ax.set_xlabel("Measured State") - ax.set_ylabel("Counts") - ax.set_title(f"Measurement Variability for Target: {target}") - ax.set_ylim(0, shots) + fig, ax = plt.subplots(figsize=(8, 3)) + bars = ax.bar(basis_states, [0] * len(basis_states), color="skyblue") + ax.set_xlabel("Basis State") + ax.set_ylabel("Real Amplitude") + ax.set_ylim(-1, 1) + ax.set_title("Grover Amplitudes") - for i, measured in enumerate(memory, 1): - counts[measured] += 1 - for bar, state in zip(bars, all_states, strict=False): - bar.set_height(counts[state]) - bar.set_color("orange" if state == str(target) else "skyblue") - ax.set_title(f"Measurement Variability (Shot {i}/{shots})\nTarget: {target}") - plt.pause(0.5) + # Start with Hadamard + qc = QuantumCircuit(n_qubits) + qc.h(range(n_qubits)) + sv = Statevector.from_instruction(qc) + plot_amplitudes_live(ax, bars, sv, basis_states, "Hadamard (Initialization)", 0, target_state, optimal_iterations) + + iteration = 1 + while plt.fignum_exists(fig.number): + # Oracle phase + oracle(qc, target_state) + sv = Statevector.from_instruction(qc) + plot_amplitudes_live( + ax, bars, sv, basis_states, "Oracle (Query Phase)", iteration, target_state, optimal_iterations + ) + + # Diffusion phase + diffusion(qc, n_qubits) + sv = Statevector.from_instruction(qc) + plot_amplitudes_live( + ax, bars, sv, basis_states, "Diffusion (Inversion Phase)", iteration, target_state, optimal_iterations + ) + + iteration += 1 plt.ioff() plt.show() diff --git a/src/grovers_visualizer/state.py b/src/grovers_visualizer/state.py index 1866963..9909e07 100644 --- a/src/grovers_visualizer/state.py +++ b/src/grovers_visualizer/state.py @@ -29,7 +29,16 @@ class QubitState: def __eq__(self, value: object, /) -> bool: if isinstance(value, QubitState): return self.bits == value.bits - return False + if isinstance(value, str): + return self.bits == value + return NotImplemented + + def __lt__(self, value: object, /) -> bool: + if isinstance(value, QubitState): + return int(self.bits, 2) < int(value.bits, 2) + if isinstance(value, str) and all(b in "01" for b in value): + return int(self.bits, 2) < int(value, 2) + return NotImplemented @override def __hash__(self) -> int: