feat: implement single shot visualizer

This commit is contained in:
Kristofers Solo 2025-04-21 14:48:07 +03:00
parent cf7d7d1e40
commit 7ecd4fe13a

View File

@ -6,11 +6,10 @@ simulation using Qiskit's Aer simulator, and visualizes the results
using matplotlib.
"""
from collections.abc import Iterator
from itertools import product
from math import floor, pi, sqrt
from typing import Iterator
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
@ -36,15 +35,16 @@ def diffusion(qc: QuantumCircuit, n: int) -> None:
qc.h(range(n))
def grover_search(n: int, target_state: QubitState) -> QuantumCircuit:
def grover_search(target_state: QubitState, iterations: int | None = None) -> QuantumCircuit:
"""Construct a Grover search circuit for the given target state."""
n = len(target_state)
qc = QuantumCircuit(n, n)
qc.h(range(n))
num_states = 2**n
if iterations is None or iterations < 0:
iterations = floor(pi / 4 * sqrt(2**n))
iterations = floor(pi / 4 * sqrt(num_states))
for _ in range(iterations):
oracle(qc, target_state)
diffusion(qc, n)
@ -75,30 +75,49 @@ def all_states(n_qubits: int) -> Iterator[QubitState]:
def main() -> None:
n_qubits = 3
shots = 1024
_, ax = plt.subplots(figsize=(8, 4))
plt.ion()
for state in all_states(n_qubits):
qc = grover_search(n_qubits, state)
print(qc.draw("text"))
shots = 20
target = QubitState("11111111111111111")
n_qubits = len(target)
qc = grover_search(target, iterations=4)
print(qc)
simulator = AerSimulator()
job = simulator.run(qc, shots=shots)
job = simulator.run(qc, shots=shots, memory=True)
result = job.result()
counts: dict[str, int] = result.get_counts(qc)
sorted_counts = dict(sorted(counts.items(), key=lambda x: x[1], reverse=True))
memory = result.get_memory(qc) # List of measurement results, one per shot
counts = result.get_counts(qc) # Dict: state string -> count
print(counts)
print(f"Target: {state}")
print("\n".join(f"'{k}': {v}" for k, v in sorted_counts.items()))
# print(qc.draw("text"))
print(f"Target: {target}")
plot_counts(ax, sorted_counts, state)
plt.pause(1)
# Ensure all possible states are present in the bar chart
all_states = ["".join(bits) for bits in product("01", repeat=n_qubits)]
counts = dict.fromkeys(all_states, 0)
# print(counts)
plt.show()
# plt.ion()
# _, ax = plt.subplots(figsize=(6, 2))
# bars = ax.bar(all_states, [0] * len(all_states), color="skyblue")
# ax.set_xlabel("Measured State")
# ax.set_ylabel("Counts")
# ax.set_title(f"Measurement Variability for Target: {target}")
# ax.set_ylim(0, shots)
for i, measured in enumerate(memory, 1):
pass
# print(measured)
# measured_be = measured[::-1]
# if measured_be in counts:
# counts[measured_be] += 1
# for bar, state in zip(bars, all_states, strict=False):
# bar.set_height(counts[state])
# bar.set_color("orange" if state == str(target) else "skyblue")
# ax.set_title(f"Measurement Variability for Target: {target} (Shot {i}/{shots})")
# plt.pause(1)
# plt.ioff()
# plt.show()
if __name__ == "__main__":