feat: implement single shot visualizer

This commit is contained in:
Kristofers Solo 2025-04-21 14:48:07 +03:00
parent cf7d7d1e40
commit 7ecd4fe13a

View File

@ -6,11 +6,10 @@ simulation using Qiskit's Aer simulator, and visualizes the results
using matplotlib. using matplotlib.
""" """
from collections.abc import Iterator
from itertools import product from itertools import product
from math import floor, pi, sqrt from math import floor, pi, sqrt
from typing import Iterator
import matplotlib.pyplot as plt
from matplotlib.axes import Axes from matplotlib.axes import Axes
from qiskit import QuantumCircuit from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator from qiskit_aer import AerSimulator
@ -36,15 +35,16 @@ def diffusion(qc: QuantumCircuit, n: int) -> None:
qc.h(range(n)) qc.h(range(n))
def grover_search(n: int, target_state: QubitState) -> QuantumCircuit: def grover_search(target_state: QubitState, iterations: int | None = None) -> QuantumCircuit:
"""Construct a Grover search circuit for the given target state.""" """Construct a Grover search circuit for the given target state."""
n = len(target_state)
qc = QuantumCircuit(n, n) qc = QuantumCircuit(n, n)
qc.h(range(n)) qc.h(range(n))
num_states = 2**n if iterations is None or iterations < 0:
iterations = floor(pi / 4 * sqrt(2**n))
iterations = floor(pi / 4 * sqrt(num_states))
for _ in range(iterations): for _ in range(iterations):
oracle(qc, target_state) oracle(qc, target_state)
diffusion(qc, n) diffusion(qc, n)
@ -75,30 +75,49 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
def main() -> None: def main() -> None:
n_qubits = 3 shots = 20
shots = 1024 target = QubitState("11111111111111111")
n_qubits = len(target)
_, ax = plt.subplots(figsize=(8, 4))
plt.ion()
for state in all_states(n_qubits):
qc = grover_search(n_qubits, state)
print(qc.draw("text"))
qc = grover_search(target, iterations=4)
print(qc)
simulator = AerSimulator() simulator = AerSimulator()
job = simulator.run(qc, shots=shots) job = simulator.run(qc, shots=shots, memory=True)
result = job.result() result = job.result()
counts: dict[str, int] = result.get_counts(qc) memory = result.get_memory(qc) # List of measurement results, one per shot
sorted_counts = dict(sorted(counts.items(), key=lambda x: x[1], reverse=True)) counts = result.get_counts(qc) # Dict: state string -> count
print(counts)
print(f"Target: {state}") # print(qc.draw("text"))
print("\n".join(f"'{k}': {v}" for k, v in sorted_counts.items())) print(f"Target: {target}")
plot_counts(ax, sorted_counts, state) # Ensure all possible states are present in the bar chart
plt.pause(1) all_states = ["".join(bits) for bits in product("01", repeat=n_qubits)]
counts = dict.fromkeys(all_states, 0)
# print(counts)
plt.show() # plt.ion()
# _, ax = plt.subplots(figsize=(6, 2))
# bars = ax.bar(all_states, [0] * len(all_states), color="skyblue")
# ax.set_xlabel("Measured State")
# ax.set_ylabel("Counts")
# ax.set_title(f"Measurement Variability for Target: {target}")
# ax.set_ylim(0, shots)
for i, measured in enumerate(memory, 1):
pass
# print(measured)
# measured_be = measured[::-1]
# if measured_be in counts:
# counts[measured_be] += 1
# for bar, state in zip(bars, all_states, strict=False):
# bar.set_height(counts[state])
# bar.set_color("orange" if state == str(target) else "skyblue")
# ax.set_title(f"Measurement Variability for Target: {target} (Shot {i}/{shots})")
# plt.pause(1)
# plt.ioff()
# plt.show()
if __name__ == "__main__": if __name__ == "__main__":