diff --git a/src/grovers_visualizer/abc/__init__.py b/src/grovers_visualizer/abc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/grovers_visualizer/abc/visualization.py b/src/grovers_visualizer/abc/visualization.py new file mode 100644 index 0000000..488110b --- /dev/null +++ b/src/grovers_visualizer/abc/visualization.py @@ -0,0 +1,47 @@ +import math +from abc import ABC, abstractmethod +from math import asin, sqrt +from typing import Self + +from grovers_visualizer.args import Args +from grovers_visualizer.plot import SinePlotData +from grovers_visualizer.state import QubitState +from grovers_visualizer.utils import all_states, optimal_grover_iterations + + +class BaseGroverVisualizer(ABC): + def __init__(self, target: QubitState, *, iterations: int = 0, pause: float = 0.5, phase: float = math.pi) -> None: + self.target: QubitState = target + self.n: int = len(self.target) + self.basis_states: list[str] = [str(b) for b in all_states(self.n)] + self.optimal: int = optimal_grover_iterations(self.n) + self.theta: float = 2 * asin(1 / sqrt(2.0**self.n)) + self.state_angle: float = 0.5 * self.theta + self.sine_data: SinePlotData = SinePlotData() + self.is_running: bool = True + self.pause: float = pause + self.iterations: int = iterations + self.phase: float = phase + self._init_figure() + + @abstractmethod + def _init_figure(self) -> None: + raise NotImplementedError + + @abstractmethod + def update(self) -> None: + """Given (iteration, Statevector), update all three plots.""" + raise NotImplementedError + + @abstractmethod + def finalize(self) -> None: + raise NotImplementedError + + @classmethod + def from_args(cls, args: Args) -> Self: + return cls( + args.target, + iterations=args.iterations, + pause=args.speed, + phase=args.phase, + ) diff --git a/src/grovers_visualizer/cli.py b/src/grovers_visualizer/cli.py index 531219f..441eea9 100644 --- a/src/grovers_visualizer/cli.py +++ b/src/grovers_visualizer/cli.py @@ -1,15 +1,9 @@ -from grovers_visualizer.simulation import grover_evolver from grovers_visualizer.visualization import GroverVisualizer from .args import Args def run_cli(args: Args) -> None: - vis = GroverVisualizer(args.target, pause=args.speed) - - for it, sv in grover_evolver(vis.target, args.iterations, phase=args.phase): - if not vis.is_running: - break - vis.update(it, sv) - + vis = GroverVisualizer.from_args(args) + vis.update() vis.finalize() diff --git a/src/grovers_visualizer/visualization.py b/src/grovers_visualizer/visualization.py index 0aa72fe..8b0e462 100644 --- a/src/grovers_visualizer/visualization.py +++ b/src/grovers_visualizer/visualization.py @@ -1,14 +1,13 @@ -from math import asin, sqrt -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, final, override import matplotlib.pyplot as plt from matplotlib.backend_bases import Event, KeyEvent from matplotlib.gridspec import GridSpec -from qiskit.quantum_info import Statevector -from grovers_visualizer.plot import SinePlotData, plot_amplitudes, plot_circle, plot_sine -from grovers_visualizer.state import QubitState -from grovers_visualizer.utils import all_states, optimal_grover_iterations +from grovers_visualizer.simulation import grover_evolver + +from .abc.visualization import BaseGroverVisualizer +from .plot import plot_amplitudes, plot_circle, plot_sine if TYPE_CHECKING: from matplotlib.axes import Axes @@ -16,20 +15,10 @@ if TYPE_CHECKING: from matplotlib.figure import Figure -class GroverVisualizer: - def __init__(self, target: QubitState, pause: float = 0.5) -> None: - self.target: QubitState = target - self.n: int = len(self.target) - self.basis_states: list[str] = [str(b) for b in all_states(self.n)] - self.optimal: int = optimal_grover_iterations(self.n) - self.theta: float = 2 * asin(1 / sqrt(2.0**self.n)) - self.state_angle: float = 0.5 * self.theta - self.sine_data: SinePlotData = SinePlotData() - self.is_running: bool = True - self.pause: float = pause - self._build_figure() - - def _build_figure(self) -> None: +@final +class GroverVisualizer(BaseGroverVisualizer): + @override + def _init_figure(self) -> None: plt.ion() self.fig: Figure = plt.figure(figsize=(14, 6)) gs = GridSpec(2, 2, width_ratios=(3, 1), figure=self.fig) @@ -49,35 +38,40 @@ class GroverVisualizer: if isinstance(event, KeyEvent) and event.key == "q": self.is_running = False - def update(self, iteration: int, sv: Statevector) -> None: + @override + def update(self) -> None: """Given (iteration, Statevector), update all three plots.""" - # amplitudes - plot_amplitudes( - self.ax_bar, - self.bars, - sv, - self.basis_states, - "Grover Iteration", - iteration, - self.target, - self.optimal, - ) + for it, sv in grover_evolver(self.target, self.iterations, phase=self.phase): + if not self.is_running: + break + # amplitudes + plot_amplitudes( + self.ax_bar, + self.bars, + sv, + self.basis_states, + "Grover Iteration", + it, + self.target, + self.optimal, + ) - # circle - plot_circle( - self.ax_circle, - iteration, - self.optimal, - self.theta, - self.state_angle, - ) + # circle + plot_circle( + self.ax_circle, + it, + self.optimal, + self.theta, + self.state_angle, + ) - # sine curve - self.sine_data.calc_and_append_probability(iteration, self.theta) + # sine curve + self.sine_data.calc_and_append_probability(it, self.theta) - plot_sine(self.ax_sine, self.sine_data) - plt.pause(self.pause) + plot_sine(self.ax_sine, self.sine_data) + plt.pause(self.pause) + @override def finalize(self) -> None: """Clean up after loop ends.""" self.fig.canvas.mpl_disconnect(self.cid) diff --git a/uv.lock b/uv.lock index b73a1ae..8ca7212 100644 --- a/uv.lock +++ b/uv.lock @@ -155,7 +155,7 @@ wheels = [ [[package]] name = "grovers-visualizer" -version = "0.4.2" +version = "0.5.0" source = { editable = "." } dependencies = [ { name = "numpy" },