feat: add step visualization

This commit is contained in:
Kristofers Solo 2025-04-22 15:38:39 +03:00
parent 2b705b0f8b
commit c928ad55a9
2 changed files with 81 additions and 32 deletions

View File

@ -11,9 +11,11 @@ from itertools import product
from math import floor, pi, sqrt
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
from qiskit.quantum_info import Statevector
from grovers_visualizer.gates import apply_phase_inversion, encode_target_state
from grovers_visualizer.state import QubitState
@ -75,40 +77,78 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
yield QubitState("".join(bits))
def plot_amplitudes_live(
ax: Axes,
bars: BarContainer,
statevector: Statevector,
basis_states: list[str],
step_label: str,
iteration: int,
target_state: QubitState | None = None,
optimal_iteration: int | None = None,
) -> None:
amplitudes = statevector.data.real # Real part of amplitudes
mean = np.mean(amplitudes)
for bar, state, amp in zip(bars, basis_states, amplitudes, strict=False):
bar.set_height(amp)
if state == target_state:
if optimal_iteration is not None and iteration == optimal_iteration:
bar.set_color("green")
else:
bar.set_color("orange")
else:
bar.set_color("skyblue")
ax.set_title(f"Iteration {iteration}: {step_label}")
ax.set_ylim(-1, 1)
# Remove previous mean line(s)
for l in ax.lines:
l.remove()
ax.axhline(mean, color="red", linestyle="--", label="Mean")
if not ax.get_legend():
ax.legend(loc="upper right")
plt.pause(1)
def main() -> None:
shots = 128
target = QubitState("1010")
n_qubits = len(target)
qc = grover_search(target, iterations=1)
simulator = AerSimulator()
job = simulator.run(qc, shots=shots, memory=True)
result = job.result()
memory = result.get_memory(qc) # List of measurement results, one per shot
print(qc) # draw scheme
print(f"Target: {target}")
# Ensure all possible states are present in the bar chart
all_states = ["".join(bits) for bits in product("01", repeat=n_qubits)]
counts = dict.fromkeys(all_states, 0)
target_state = QubitState("1010")
n_qubits = len(target_state)
basis_states = [str(bit) for bit in all_states(n_qubits)]
optimal_iterations = floor(pi / 4 * sqrt(2**n_qubits))
plt.ion()
_, ax = plt.subplots(figsize=(6, 2))
bars = ax.bar(all_states, [0] * len(all_states), color="skyblue")
ax.set_xlabel("Measured State")
ax.set_ylabel("Counts")
ax.set_title(f"Measurement Variability for Target: {target}")
ax.set_ylim(0, shots)
fig, ax = plt.subplots(figsize=(8, 3))
bars = ax.bar(basis_states, [0] * len(basis_states), color="skyblue")
ax.set_xlabel("Basis State")
ax.set_ylabel("Real Amplitude")
ax.set_ylim(-1, 1)
ax.set_title("Grover Amplitudes")
for i, measured in enumerate(memory, 1):
counts[measured] += 1
for bar, state in zip(bars, all_states, strict=False):
bar.set_height(counts[state])
bar.set_color("orange" if state == str(target) else "skyblue")
ax.set_title(f"Measurement Variability (Shot {i}/{shots})\nTarget: {target}")
plt.pause(0.5)
# Start with Hadamard
qc = QuantumCircuit(n_qubits)
qc.h(range(n_qubits))
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(ax, bars, sv, basis_states, "Hadamard (Initialization)", 0, target_state, optimal_iterations)
iteration = 1
while plt.fignum_exists(fig.number):
# Oracle phase
oracle(qc, target_state)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(
ax, bars, sv, basis_states, "Oracle (Query Phase)", iteration, target_state, optimal_iterations
)
# Diffusion phase
diffusion(qc, n_qubits)
sv = Statevector.from_instruction(qc)
plot_amplitudes_live(
ax, bars, sv, basis_states, "Diffusion (Inversion Phase)", iteration, target_state, optimal_iterations
)
iteration += 1
plt.ioff()
plt.show()

View File

@ -29,7 +29,16 @@ class QubitState:
def __eq__(self, value: object, /) -> bool:
if isinstance(value, QubitState):
return self.bits == value.bits
return False
if isinstance(value, str):
return self.bits == value
return NotImplemented
def __lt__(self, value: object, /) -> bool:
if isinstance(value, QubitState):
return int(self.bits, 2) < int(value.bits, 2)
if isinstance(value, str) and all(b in "01" for b in value):
return int(self.bits, 2) < int(value, 2)
return NotImplemented
@override
def __hash__(self) -> int: