diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index eb4ca52..4f3215f 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -6,11 +6,10 @@ simulation using Qiskit's Aer simulator, and visualizes the results using matplotlib. """ +from collections.abc import Iterator from itertools import product from math import floor, pi, sqrt -from typing import Iterator -import matplotlib.pyplot as plt from matplotlib.axes import Axes from qiskit import QuantumCircuit from qiskit_aer import AerSimulator @@ -36,15 +35,16 @@ def diffusion(qc: QuantumCircuit, n: int) -> None: qc.h(range(n)) -def grover_search(n: int, target_state: QubitState) -> QuantumCircuit: +def grover_search(target_state: QubitState, iterations: int | None = None) -> QuantumCircuit: """Construct a Grover search circuit for the given target state.""" + n = len(target_state) qc = QuantumCircuit(n, n) qc.h(range(n)) - num_states = 2**n + if iterations is None or iterations < 0: + iterations = floor(pi / 4 * sqrt(2**n)) - iterations = floor(pi / 4 * sqrt(num_states)) for _ in range(iterations): oracle(qc, target_state) diffusion(qc, n) @@ -75,30 +75,49 @@ def all_states(n_qubits: int) -> Iterator[QubitState]: def main() -> None: - n_qubits = 3 - shots = 1024 + shots = 20 + target = QubitState("11111111111111111") + n_qubits = len(target) - _, ax = plt.subplots(figsize=(8, 4)) - plt.ion() + qc = grover_search(target, iterations=4) + print(qc) + simulator = AerSimulator() + job = simulator.run(qc, shots=shots, memory=True) + result = job.result() + memory = result.get_memory(qc) # List of measurement results, one per shot + counts = result.get_counts(qc) # Dict: state string -> count + print(counts) - for state in all_states(n_qubits): - qc = grover_search(n_qubits, state) + # print(qc.draw("text")) + print(f"Target: {target}") - print(qc.draw("text")) + # Ensure all possible states are present in the bar chart + all_states = ["".join(bits) for bits in product("01", repeat=n_qubits)] + counts = dict.fromkeys(all_states, 0) + # print(counts) - simulator = AerSimulator() - job = simulator.run(qc, shots=shots) - result = job.result() - counts: dict[str, int] = result.get_counts(qc) - sorted_counts = dict(sorted(counts.items(), key=lambda x: x[1], reverse=True)) + # plt.ion() + # _, ax = plt.subplots(figsize=(6, 2)) + # bars = ax.bar(all_states, [0] * len(all_states), color="skyblue") + # ax.set_xlabel("Measured State") + # ax.set_ylabel("Counts") + # ax.set_title(f"Measurement Variability for Target: {target}") + # ax.set_ylim(0, shots) - print(f"Target: {state}") - print("\n".join(f"'{k}': {v}" for k, v in sorted_counts.items())) + for i, measured in enumerate(memory, 1): + pass + # print(measured) + # measured_be = measured[::-1] + # if measured_be in counts: + # counts[measured_be] += 1 + # for bar, state in zip(bars, all_states, strict=False): + # bar.set_height(counts[state]) + # bar.set_color("orange" if state == str(target) else "skyblue") + # ax.set_title(f"Measurement Variability for Target: {target} (Shot {i}/{shots})") + # plt.pause(1) - plot_counts(ax, sorted_counts, state) - plt.pause(1) - - plt.show() + # plt.ioff() + # plt.show() if __name__ == "__main__":