refactor(cli): hide update implementation

This commit is contained in:
Kristofers Solo 2025-05-13 11:28:39 +03:00
parent a405df416c
commit b364eddbc4
Signed by: kristoferssolo
GPG Key ID: 74FF8144483D82C8
5 changed files with 88 additions and 53 deletions

View File

View File

@ -0,0 +1,47 @@
import math
from abc import ABC, abstractmethod
from math import asin, sqrt
from typing import Self
from grovers_visualizer.args import Args
from grovers_visualizer.plot import SinePlotData
from grovers_visualizer.state import QubitState
from grovers_visualizer.utils import all_states, optimal_grover_iterations
class BaseGroverVisualizer(ABC):
def __init__(self, target: QubitState, *, iterations: int = 0, pause: float = 0.5, phase: float = math.pi) -> None:
self.target: QubitState = target
self.n: int = len(self.target)
self.basis_states: list[str] = [str(b) for b in all_states(self.n)]
self.optimal: int = optimal_grover_iterations(self.n)
self.theta: float = 2 * asin(1 / sqrt(2.0**self.n))
self.state_angle: float = 0.5 * self.theta
self.sine_data: SinePlotData = SinePlotData()
self.is_running: bool = True
self.pause: float = pause
self.iterations: int = iterations
self.phase: float = phase
self._init_figure()
@abstractmethod
def _init_figure(self) -> None:
raise NotImplementedError
@abstractmethod
def update(self) -> None:
"""Given (iteration, Statevector), update all three plots."""
raise NotImplementedError
@abstractmethod
def finalize(self) -> None:
raise NotImplementedError
@classmethod
def from_args(cls, args: Args) -> Self:
return cls(
args.target,
iterations=args.iterations,
pause=args.speed,
phase=args.phase,
)

View File

@ -1,15 +1,9 @@
from grovers_visualizer.simulation import grover_evolver
from grovers_visualizer.visualization import GroverVisualizer
from .args import Args
def run_cli(args: Args) -> None:
vis = GroverVisualizer(args.target, pause=args.speed)
for it, sv in grover_evolver(vis.target, args.iterations, phase=args.phase):
if not vis.is_running:
break
vis.update(it, sv)
vis = GroverVisualizer.from_args(args)
vis.update()
vis.finalize()

View File

@ -1,14 +1,13 @@
from math import asin, sqrt
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, final, override
import matplotlib.pyplot as plt
from matplotlib.backend_bases import Event, KeyEvent
from matplotlib.gridspec import GridSpec
from qiskit.quantum_info import Statevector
from grovers_visualizer.plot import SinePlotData, plot_amplitudes, plot_circle, plot_sine
from grovers_visualizer.state import QubitState
from grovers_visualizer.utils import all_states, optimal_grover_iterations
from grovers_visualizer.simulation import grover_evolver
from .abc.visualization import BaseGroverVisualizer
from .plot import plot_amplitudes, plot_circle, plot_sine
if TYPE_CHECKING:
from matplotlib.axes import Axes
@ -16,20 +15,10 @@ if TYPE_CHECKING:
from matplotlib.figure import Figure
class GroverVisualizer:
def __init__(self, target: QubitState, pause: float = 0.5) -> None:
self.target: QubitState = target
self.n: int = len(self.target)
self.basis_states: list[str] = [str(b) for b in all_states(self.n)]
self.optimal: int = optimal_grover_iterations(self.n)
self.theta: float = 2 * asin(1 / sqrt(2.0**self.n))
self.state_angle: float = 0.5 * self.theta
self.sine_data: SinePlotData = SinePlotData()
self.is_running: bool = True
self.pause: float = pause
self._build_figure()
def _build_figure(self) -> None:
@final
class GroverVisualizer(BaseGroverVisualizer):
@override
def _init_figure(self) -> None:
plt.ion()
self.fig: Figure = plt.figure(figsize=(14, 6))
gs = GridSpec(2, 2, width_ratios=(3, 1), figure=self.fig)
@ -49,35 +38,40 @@ class GroverVisualizer:
if isinstance(event, KeyEvent) and event.key == "q":
self.is_running = False
def update(self, iteration: int, sv: Statevector) -> None:
@override
def update(self) -> None:
"""Given (iteration, Statevector), update all three plots."""
# amplitudes
plot_amplitudes(
self.ax_bar,
self.bars,
sv,
self.basis_states,
"Grover Iteration",
iteration,
self.target,
self.optimal,
)
for it, sv in grover_evolver(self.target, self.iterations, phase=self.phase):
if not self.is_running:
break
# amplitudes
plot_amplitudes(
self.ax_bar,
self.bars,
sv,
self.basis_states,
"Grover Iteration",
it,
self.target,
self.optimal,
)
# circle
plot_circle(
self.ax_circle,
iteration,
self.optimal,
self.theta,
self.state_angle,
)
# circle
plot_circle(
self.ax_circle,
it,
self.optimal,
self.theta,
self.state_angle,
)
# sine curve
self.sine_data.calc_and_append_probability(iteration, self.theta)
# sine curve
self.sine_data.calc_and_append_probability(it, self.theta)
plot_sine(self.ax_sine, self.sine_data)
plt.pause(self.pause)
plot_sine(self.ax_sine, self.sine_data)
plt.pause(self.pause)
@override
def finalize(self) -> None:
"""Clean up after loop ends."""
self.fig.canvas.mpl_disconnect(self.cid)

View File

@ -155,7 +155,7 @@ wheels = [
[[package]]
name = "grovers-visualizer"
version = "0.4.2"
version = "0.5.0"
source = { editable = "." }
dependencies = [
{ name = "numpy" },