mirror of
https://github.com/kristoferssolo/grovers-visualizer.git
synced 2025-10-21 20:10:35 +00:00
refactor: compact down function calls
This commit is contained in:
parent
c928ad55a9
commit
2defbc9d77
@ -9,9 +9,11 @@ using matplotlib.
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from math import floor, pi, sqrt
|
from math import floor, pi, sqrt
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from matplotlib.container import BarContainer
|
from matplotlib.container import BarContainer
|
||||||
from qiskit import QuantumCircuit
|
from qiskit import QuantumCircuit
|
||||||
@ -77,6 +79,20 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
|
|||||||
yield QubitState("".join(bits))
|
yield QubitState("".join(bits))
|
||||||
|
|
||||||
|
|
||||||
|
def optimal_grover_iterations(n_qubits: int) -> int:
|
||||||
|
"""Return the optimal number of Grover iterations for n qubits."""
|
||||||
|
return floor(pi / 4 * sqrt(2**n_qubits))
|
||||||
|
|
||||||
|
|
||||||
|
def get_bar_color(state: str, target_state: QubitState | None, iteration: int, optimal_iteration: int | None) -> str:
|
||||||
|
"""Return the color for a bar based on state and iteration."""
|
||||||
|
if state != target_state:
|
||||||
|
return "skyblue"
|
||||||
|
if optimal_iteration and iteration % optimal_iteration == 0 and iteration != 0:
|
||||||
|
return "green"
|
||||||
|
return "orange"
|
||||||
|
|
||||||
|
|
||||||
def plot_amplitudes_live(
|
def plot_amplitudes_live(
|
||||||
ax: Axes,
|
ax: Axes,
|
||||||
bars: BarContainer,
|
bars: BarContainer,
|
||||||
@ -87,25 +103,20 @@ def plot_amplitudes_live(
|
|||||||
target_state: QubitState | None = None,
|
target_state: QubitState | None = None,
|
||||||
optimal_iteration: int | None = None,
|
optimal_iteration: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
amplitudes = statevector.data.real # Real part of amplitudes
|
amplitudes: npt.NDArray[np.float64] = statevector.data.real # Real part of amplitudes
|
||||||
mean = np.mean(amplitudes)
|
mean = np.mean(amplitudes)
|
||||||
|
|
||||||
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
|
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
|
||||||
bar.set_height(amp)
|
bar.set_height(amp)
|
||||||
if state == target_state:
|
bar.set_color(get_bar_color(state, target_state, iteration, optimal_iteration))
|
||||||
if optimal_iteration is not None and iteration == optimal_iteration:
|
|
||||||
bar.set_color("green")
|
|
||||||
else:
|
|
||||||
bar.set_color("orange")
|
|
||||||
else:
|
|
||||||
bar.set_color("skyblue")
|
|
||||||
ax.set_title(f"Iteration {iteration}: {step_label}")
|
ax.set_title(f"Iteration {iteration}: {step_label}")
|
||||||
ax.set_ylim(-1, 1)
|
ax.set_ylim(-1, 1)
|
||||||
# Remove previous mean line(s)
|
|
||||||
|
|
||||||
for l in ax.lines:
|
for l in ax.lines: # Remove previous mean line(s)
|
||||||
l.remove()
|
l.remove()
|
||||||
|
|
||||||
ax.axhline(mean, color="red", linestyle="--", label="Mean")
|
ax.axhline(float(mean), color="red", linestyle="--", label="Mean")
|
||||||
|
|
||||||
if not ax.get_legend():
|
if not ax.get_legend():
|
||||||
ax.legend(loc="upper right")
|
ax.legend(loc="upper right")
|
||||||
@ -116,7 +127,7 @@ def main() -> None:
|
|||||||
target_state = QubitState("1010")
|
target_state = QubitState("1010")
|
||||||
n_qubits = len(target_state)
|
n_qubits = len(target_state)
|
||||||
basis_states = [str(bit) for bit in all_states(n_qubits)]
|
basis_states = [str(bit) for bit in all_states(n_qubits)]
|
||||||
optimal_iterations = floor(pi / 4 * sqrt(2**n_qubits))
|
optimal_iterations = optimal_grover_iterations(n_qubits)
|
||||||
|
|
||||||
plt.ion()
|
plt.ion()
|
||||||
fig, ax = plt.subplots(figsize=(8, 3))
|
fig, ax = plt.subplots(figsize=(8, 3))
|
||||||
@ -126,28 +137,25 @@ def main() -> None:
|
|||||||
ax.set_ylim(-1, 1)
|
ax.set_ylim(-1, 1)
|
||||||
ax.set_title("Grover Amplitudes")
|
ax.set_title("Grover Amplitudes")
|
||||||
|
|
||||||
|
def step_and_plot(
|
||||||
|
operation: Callable[[QuantumCircuit], None] | None,
|
||||||
|
step_label: str,
|
||||||
|
iteration: int,
|
||||||
|
) -> None:
|
||||||
|
if operation is not None:
|
||||||
|
operation(qc)
|
||||||
|
sv = Statevector.from_instruction(qc)
|
||||||
|
plot_amplitudes_live(ax, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations)
|
||||||
|
|
||||||
# Start with Hadamard
|
# Start with Hadamard
|
||||||
qc = QuantumCircuit(n_qubits)
|
qc = QuantumCircuit(n_qubits)
|
||||||
qc.h(range(n_qubits))
|
qc.h(range(n_qubits))
|
||||||
sv = Statevector.from_instruction(qc)
|
step_and_plot(None, "Hadamard (Initialization)", 0)
|
||||||
plot_amplitudes_live(ax, bars, sv, basis_states, "Hadamard (Initialization)", 0, target_state, optimal_iterations)
|
|
||||||
|
|
||||||
iteration = 1
|
iteration = 1
|
||||||
while plt.fignum_exists(fig.number):
|
while plt.fignum_exists(fig.number):
|
||||||
# Oracle phase
|
step_and_plot(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration)
|
||||||
oracle(qc, target_state)
|
step_and_plot(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration)
|
||||||
sv = Statevector.from_instruction(qc)
|
|
||||||
plot_amplitudes_live(
|
|
||||||
ax, bars, sv, basis_states, "Oracle (Query Phase)", iteration, target_state, optimal_iterations
|
|
||||||
)
|
|
||||||
|
|
||||||
# Diffusion phase
|
|
||||||
diffusion(qc, n_qubits)
|
|
||||||
sv = Statevector.from_instruction(qc)
|
|
||||||
plot_amplitudes_live(
|
|
||||||
ax, bars, sv, basis_states, "Diffusion (Inversion Phase)", iteration, target_state, optimal_iterations
|
|
||||||
)
|
|
||||||
|
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
|
||||||
plt.ioff()
|
plt.ioff()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user