refactor: compact down function calls

This commit is contained in:
Kristofers Solo 2025-04-22 16:03:44 +03:00
parent c928ad55a9
commit 2defbc9d77

View File

@ -9,9 +9,11 @@ using matplotlib.
from collections.abc import Iterator
from itertools import product
from math import floor, pi, sqrt
from typing import Callable
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from qiskit import QuantumCircuit
@ -77,6 +79,20 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
yield QubitState("".join(bits))
def optimal_grover_iterations(n_qubits: int) -> int:
"""Return the optimal number of Grover iterations for n qubits."""
return floor(pi / 4 * sqrt(2**n_qubits))
def get_bar_color(state: str, target_state: QubitState | None, iteration: int, optimal_iteration: int | None) -> str:
"""Return the color for a bar based on state and iteration."""
if state != target_state:
return "skyblue"
if optimal_iteration and iteration % optimal_iteration == 0 and iteration != 0:
return "green"
return "orange"
def plot_amplitudes_live(
ax: Axes,
bars: BarContainer,
@ -87,25 +103,20 @@ def plot_amplitudes_live(
target_state: QubitState | None = None,
optimal_iteration: int | None = None,
) -> None:
amplitudes = statevector.data.real # Real part of amplitudes
amplitudes: npt.NDArray[np.float64] = statevector.data.real # Real part of amplitudes
mean = np.mean(amplitudes)
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
bar.set_height(amp)
if state == target_state:
if optimal_iteration is not None and iteration == optimal_iteration:
bar.set_color("green")
else:
bar.set_color("orange")
else:
bar.set_color("skyblue")
bar.set_color(get_bar_color(state, target_state, iteration, optimal_iteration))
ax.set_title(f"Iteration {iteration}: {step_label}")
ax.set_ylim(-1, 1)
# Remove previous mean line(s)
for l in ax.lines:
for l in ax.lines: # Remove previous mean line(s)
l.remove()
ax.axhline(mean, color="red", linestyle="--", label="Mean")
ax.axhline(float(mean), color="red", linestyle="--", label="Mean")
if not ax.get_legend():
ax.legend(loc="upper right")
@ -116,7 +127,7 @@ def main() -> None:
target_state = QubitState("1010")
n_qubits = len(target_state)
basis_states = [str(bit) for bit in all_states(n_qubits)]
optimal_iterations = floor(pi / 4 * sqrt(2**n_qubits))
optimal_iterations = optimal_grover_iterations(n_qubits)
plt.ion()
fig, ax = plt.subplots(figsize=(8, 3))
@ -126,28 +137,25 @@ def main() -> None:
ax.set_ylim(-1, 1)
ax.set_title("Grover Amplitudes")
def step_and_plot(
operation: Callable[[QuantumCircuit], None] | None,
step_label: str,
iteration: int,
) -> None:
if operation is not None:
operation(qc)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(ax, bars, sv, basis_states, step_label, iteration, target_state, optimal_iterations)
# Start with Hadamard
qc = QuantumCircuit(n_qubits)
qc.h(range(n_qubits))
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(ax, bars, sv, basis_states, "Hadamard (Initialization)", 0, target_state, optimal_iterations)
step_and_plot(None, "Hadamard (Initialization)", 0)
iteration = 1
while plt.fignum_exists(fig.number):
# Oracle phase
oracle(qc, target_state)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(
ax, bars, sv, basis_states, "Oracle (Query Phase)", iteration, target_state, optimal_iterations
)
# Diffusion phase
diffusion(qc, n_qubits)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(
ax, bars, sv, basis_states, "Diffusion (Inversion Phase)", iteration, target_state, optimal_iterations
)
step_and_plot(lambda qc: oracle(qc, target_state), "Oracle (Query Phase)", iteration)
step_and_plot(lambda qc: diffusion(qc, n_qubits), "Diffusion (Inversion Phase)", iteration)
iteration += 1
plt.ioff()