feat: add step visualization

This commit is contained in:
Kristofers Solo 2025-04-22 15:38:39 +03:00
parent 2b705b0f8b
commit c928ad55a9
2 changed files with 81 additions and 32 deletions

View File

@ -11,9 +11,11 @@ from itertools import product
from math import floor, pi, sqrt from math import floor, pi, sqrt
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from qiskit import QuantumCircuit from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator from qiskit.quantum_info import Statevector
from grovers_visualizer.gates import apply_phase_inversion, encode_target_state from grovers_visualizer.gates import apply_phase_inversion, encode_target_state
from grovers_visualizer.state import QubitState from grovers_visualizer.state import QubitState
@ -75,40 +77,78 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
yield QubitState("".join(bits)) yield QubitState("".join(bits))
def plot_amplitudes_live(
ax: Axes,
bars: BarContainer,
statevector: Statevector,
basis_states: list[str],
step_label: str,
iteration: int,
target_state: QubitState | None = None,
optimal_iteration: int | None = None,
) -> None:
amplitudes = statevector.data.real # Real part of amplitudes
mean = np.mean(amplitudes)
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
bar.set_height(amp)
if state == target_state:
if optimal_iteration is not None and iteration == optimal_iteration:
bar.set_color("green")
else:
bar.set_color("orange")
else:
bar.set_color("skyblue")
ax.set_title(f"Iteration {iteration}: {step_label}")
ax.set_ylim(-1, 1)
# Remove previous mean line(s)
for l in ax.lines:
l.remove()
ax.axhline(mean, color="red", linestyle="--", label="Mean")
if not ax.get_legend():
ax.legend(loc="upper right")
plt.pause(1)
def main() -> None: def main() -> None:
shots = 128 target_state = QubitState("1010")
target = QubitState("1010") n_qubits = len(target_state)
n_qubits = len(target) basis_states = [str(bit) for bit in all_states(n_qubits)]
optimal_iterations = floor(pi / 4 * sqrt(2**n_qubits))
qc = grover_search(target, iterations=1)
simulator = AerSimulator()
job = simulator.run(qc, shots=shots, memory=True)
result = job.result()
memory = result.get_memory(qc) # List of measurement results, one per shot
print(qc) # draw scheme
print(f"Target: {target}")
# Ensure all possible states are present in the bar chart
all_states = ["".join(bits) for bits in product("01", repeat=n_qubits)]
counts = dict.fromkeys(all_states, 0)
plt.ion() plt.ion()
_, ax = plt.subplots(figsize=(6, 2)) fig, ax = plt.subplots(figsize=(8, 3))
bars = ax.bar(all_states, [0] * len(all_states), color="skyblue") bars = ax.bar(basis_states, [0] * len(basis_states), color="skyblue")
ax.set_xlabel("Measured State") ax.set_xlabel("Basis State")
ax.set_ylabel("Counts") ax.set_ylabel("Real Amplitude")
ax.set_title(f"Measurement Variability for Target: {target}") ax.set_ylim(-1, 1)
ax.set_ylim(0, shots) ax.set_title("Grover Amplitudes")
for i, measured in enumerate(memory, 1): # Start with Hadamard
counts[measured] += 1 qc = QuantumCircuit(n_qubits)
for bar, state in zip(bars, all_states, strict=False): qc.h(range(n_qubits))
bar.set_height(counts[state]) sv = Statevector.from_instruction(qc)
bar.set_color("orange" if state == str(target) else "skyblue") plot_amplitudes_live(ax, bars, sv, basis_states, "Hadamard (Initialization)", 0, target_state, optimal_iterations)
ax.set_title(f"Measurement Variability (Shot {i}/{shots})\nTarget: {target}")
plt.pause(0.5) iteration = 1
while plt.fignum_exists(fig.number):
# Oracle phase
oracle(qc, target_state)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(
ax, bars, sv, basis_states, "Oracle (Query Phase)", iteration, target_state, optimal_iterations
)
# Diffusion phase
diffusion(qc, n_qubits)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(
ax, bars, sv, basis_states, "Diffusion (Inversion Phase)", iteration, target_state, optimal_iterations
)
iteration += 1
plt.ioff() plt.ioff()
plt.show() plt.show()

View File

@ -29,7 +29,16 @@ class QubitState:
def __eq__(self, value: object, /) -> bool: def __eq__(self, value: object, /) -> bool:
if isinstance(value, QubitState): if isinstance(value, QubitState):
return self.bits == value.bits return self.bits == value.bits
return False if isinstance(value, str):
return self.bits == value
return NotImplemented
def __lt__(self, value: object, /) -> bool:
if isinstance(value, QubitState):
return int(self.bits, 2) < int(value.bits, 2)
if isinstance(value, str) and all(b in "01" for b in value):
return int(self.bits, 2) < int(value, 2)
return NotImplemented
@override @override
def __hash__(self) -> int: def __hash__(self) -> int: