mirror of
https://github.com/kristoferssolo/grovers-visualizer.git
synced 2026-02-04 06:42:03 +00:00
feat(args): add args functionality
This commit is contained in:
@@ -10,36 +10,23 @@ from math import asin, sqrt
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.axes import Axes
|
||||
from qiskit import QuantumCircuit
|
||||
from qiskit.quantum_info import Statevector
|
||||
|
||||
from grovers_visualizer.circuit import diffusion, oracle
|
||||
from grovers_visualizer.parse import parse_args
|
||||
from grovers_visualizer.plot import draw_grover_circle, plot_amplitudes_live
|
||||
from grovers_visualizer.state import QubitState
|
||||
from grovers_visualizer.utils import all_states, optimal_grover_iterations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
|
||||
def plot_counts(ax: Axes, counts: dict[str, int], target_state: QubitState) -> None:
|
||||
"""Display a bar chart for the measurement results."""
|
||||
|
||||
# Sort the states
|
||||
states = list(counts.keys())
|
||||
frequencies = [counts[s] for s in states]
|
||||
|
||||
ax.clear()
|
||||
ax.bar(states, frequencies, color="skyblue")
|
||||
ax.set_xlabel("Measured State")
|
||||
ax.set_ylabel("Counts")
|
||||
ax.set_title(f"Measurement Counts for Target: {target_state}")
|
||||
ax.set_ylim(0, max(frequencies) * 1.2)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
target_state = QubitState("1010")
|
||||
args = parse_args()
|
||||
|
||||
target_state = args.target
|
||||
n_qubits = len(target_state)
|
||||
basis_states = [str(bit) for bit in all_states(n_qubits)]
|
||||
optimal_iterations = optimal_grover_iterations(n_qubits)
|
||||
@@ -66,7 +53,7 @@ def main() -> None:
|
||||
plot_amplitudes_live(ax_bar, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations)
|
||||
draw_grover_circle(ax_circle, iteration, optimal_iterations, theta, state_angle)
|
||||
|
||||
plt.pause(0.5)
|
||||
plt.pause(args.speed)
|
||||
|
||||
# Start with Hadamard
|
||||
qc = QuantumCircuit(n_qubits)
|
||||
@@ -74,14 +61,24 @@ def main() -> None:
|
||||
iterate_and_plot(None, "Hadamard (Initialization)", 0)
|
||||
|
||||
iteration = 1
|
||||
while plt.fignum_exists(fig.number):
|
||||
running = True
|
||||
|
||||
def on_key(event) -> None:
|
||||
nonlocal running
|
||||
if event.key == "q":
|
||||
running = False
|
||||
|
||||
cid = fig.canvas.mpl_connect("key_press_event", on_key)
|
||||
while plt.fignum_exists(fig.number) and running:
|
||||
iterate_and_plot(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration)
|
||||
iterate_and_plot(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration)
|
||||
|
||||
iteration += 1
|
||||
if args.iterations > 0 and iteration > args.iterations:
|
||||
break
|
||||
|
||||
fig.canvas.mpl_disconnect(cid)
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
40
src/grovers_visualizer/parse.py
Normal file
40
src/grovers_visualizer/parse.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass
|
||||
|
||||
from grovers_visualizer.state import QubitState
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
target: QubitState
|
||||
iterations: int
|
||||
speed: float
|
||||
|
||||
|
||||
def parse_args() -> Args:
|
||||
parser = ArgumentParser(description="Grover's Algorithm Visualizer")
|
||||
parser.add_argument(
|
||||
"target",
|
||||
type=str,
|
||||
help="Target bitstring (e.g., 1010)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of Grover iterations (default: 0 (infinite))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--speed",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Pause duration (seconds) between steps (deafult: 0.5)",
|
||||
)
|
||||
ns = parser.parse_args()
|
||||
return Args(
|
||||
target=QubitState(ns.target),
|
||||
iterations=ns.iterations,
|
||||
speed=ns.speed,
|
||||
)
|
||||
Reference in New Issue
Block a user