diff --git a/src/grovers_visualizer/circuit.py b/src/grovers_visualizer/circuit.py index 9539223..7b4e219 100644 --- a/src/grovers_visualizer/circuit.py +++ b/src/grovers_visualizer/circuit.py @@ -2,7 +2,6 @@ from math import floor, pi, sqrt from qiskit import QuantumCircuit -from .gates import apply_phase_inversion, encode_target_state from .state import QubitState @@ -39,3 +38,20 @@ def diffusion(qc: QuantumCircuit, n: int) -> None: apply_phase_inversion(qc, n) qc.x(range(n)) qc.h(range(n)) + + +def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None: + """Apply X gates to qubits where the target state bit is '0'.""" + for i, bit in enumerate(reversed(target_state)): + if bit == 0: + qc.x(i) + + +def apply_phase_inversion(qc: QuantumCircuit, n: int) -> None: + """Apply a multi-controlled phase inversion (Z) to the marked state.""" + if n == 1: + qc.z(0) + return + qc.h(n - 1) + qc.mcx(list(range(n - 1)), n - 1) + qc.h(n - 1) diff --git a/src/grovers_visualizer/gates.py b/src/grovers_visualizer/gates.py index 2a3c056..e69de29 100644 --- a/src/grovers_visualizer/gates.py +++ b/src/grovers_visualizer/gates.py @@ -1,20 +0,0 @@ -from qiskit import QuantumCircuit - -from .state import QubitState - - -def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None: - """Apply X gates to qubits where the target state bit is '0'.""" - for i, bit in enumerate(reversed(target_state)): - if bit == 0: - qc.x(i) - - -def apply_phase_inversion(qc: QuantumCircuit, n: int) -> None: - """Apply a multi-controlled phase inversion (Z) to the marked state.""" - if n == 1: - qc.z(0) - return - qc.h(n - 1) - qc.mcx(list(range(n - 1)), n - 1) - qc.h(n - 1) diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index a7da90a..b15ec2a 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -7,22 +7,19 @@ using matplotlib. """ from math import asin, sqrt -from typing import TYPE_CHECKING, Callable +from typing import Callable import matplotlib.pyplot as plt from matplotlib.backend_bases import KeyEvent +from matplotlib.gridspec import GridSpec from qiskit import QuantumCircuit from qiskit.quantum_info import Statevector from grovers_visualizer.circuit import diffusion, oracle from grovers_visualizer.parse import parse_args -from grovers_visualizer.plot import draw_grover_circle, plot_amplitudes_live +from grovers_visualizer.plot import SinePlotData, plot_amplitudes, plot_circle, plot_sine from grovers_visualizer.utils import all_states, optimal_grover_iterations -if TYPE_CHECKING: - from matplotlib.axes import Axes - from matplotlib.figure import Figure - def main() -> None: args = parse_args() @@ -35,13 +32,18 @@ def main() -> None: state_angle = 0.5 * theta plt.ion() - subplt: tuple[Figure, tuple[Axes, Axes]] = plt.subplots(1, 2, width_ratios=(3, 1), figsize=(12, 4)) - fig, (ax_bar, ax_circle) = subplt + fig = plt.figure(figsize=(14, 6)) + gs = GridSpec(2, 2, width_ratios=(3, 1), figure=fig) + ax_bar = fig.add_subplot(gs[0, 0]) + ax_sine = fig.add_subplot(gs[1, 0]) + ax_circle = fig.add_subplot(gs[:, 1]) bars = ax_bar.bar(basis_states, [0] * len(basis_states), color="skyblue") ax_bar.set_ylim(-1, 1) ax_bar.set_title("Amplitudes (example)") - def iterate_and_plot( + sine_data = SinePlotData() + + def plot_bar( operation: Callable[[QuantumCircuit], None] | None, step_label: str, iteration: int, @@ -49,15 +51,12 @@ def main() -> None: if operation is not None: operation(qc) sv = Statevector.from_instruction(qc) - plot_amplitudes_live(ax_bar, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations) - draw_grover_circle(ax_circle, iteration, optimal_iterations, theta, state_angle) - - plt.pause(args.speed) + plot_amplitudes(ax_bar, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations) # Start with Hadamard qc = QuantumCircuit(n_qubits) qc.h(range(n_qubits)) - iterate_and_plot(None, "Hadamard (Initialization)", 0) + plot_bar(None, "Hadamard (Initialization)", 0) iteration = 1 running = True @@ -69,8 +68,14 @@ def main() -> None: cid = fig.canvas.mpl_connect("key_press_event", on_key) while plt.fignum_exists(fig.number) and running: - iterate_and_plot(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration) - iterate_and_plot(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration) + plot_bar(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration) + plot_bar(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration) + + plot_circle(ax_circle, iteration, optimal_iterations, theta, state_angle) + sine_data.calc_and_append_probability(iteration, theta) + plot_sine(ax_sine, sine_data) + + plt.pause(args.speed) iteration += 1 if args.iterations > 0 and iteration > args.iterations: diff --git a/src/grovers_visualizer/plot.py b/src/grovers_visualizer/plot.py index 77eda39..e5788a5 100644 --- a/src/grovers_visualizer/plot.py +++ b/src/grovers_visualizer/plot.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field from math import cos, sin import numpy as np @@ -11,7 +12,7 @@ from .state import QubitState from .utils import get_bar_color, is_optimal_iteration, sign -def plot_amplitudes_live( +def plot_amplitudes( ax: Axes, bars: BarContainer, statevector: Statevector, @@ -42,7 +43,7 @@ def plot_amplitudes_live( ax.legend(loc="upper right") -def draw_grover_circle( +def plot_circle( ax: Axes, iteration: int, optimal_iterations: int, @@ -98,3 +99,31 @@ def draw_grover_circle( ax.set_title( f"Grover State Vector Rotation\nIteration {iteration} | Probability of target: {prob}{' (optimal)' if is_optimal else ''}" ) + + +@dataclass +class SinePlotData: + x: list[float] = field(default_factory=list) + y: list[float] = field(default_factory=list) + + def append(self, x: float, y: float) -> None: + self.x.append(x) + self.y.append(y) + + def calc_and_append_probability(self, iteration: int, theta: float) -> None: + prob = sin((2 * iteration + 1) * theta / 2) ** 2 + self.append(iteration, prob) + + +def plot_sine( + ax: Axes, + sine_data: SinePlotData, +) -> None: + ax.clear() + ax.plot(sine_data.x, sine_data.y, marker="o", color="purple", label="Target Probability") + ax.set_xlabel("Iteration") + ax.set_ylabel("Probability") + ax.set_title("Grover Target Probability vs. Iteration") + ax.set_ylim(0, 1) + ax.set_xlim(0, max(10, max(sine_data.x) + 1)) + ax.legend()