From 17ed8bfd53769fba312dbf64efa488bd2fd996f9 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 19 Apr 2025 16:24:02 +0300 Subject: [PATCH] refactor: update function names --- src/grovers_visualizer/gates.py | 22 +++++++++++++ src/grovers_visualizer/main.py | 57 ++++++++++++--------------------- 2 files changed, 42 insertions(+), 37 deletions(-) create mode 100644 src/grovers_visualizer/gates.py diff --git a/src/grovers_visualizer/gates.py b/src/grovers_visualizer/gates.py new file mode 100644 index 0000000..531ad3c --- /dev/null +++ b/src/grovers_visualizer/gates.py @@ -0,0 +1,22 @@ +from qiskit import QuantumCircuit + +from grovers_visualizer.state import QubitState + + +def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None: + """Apply X gates to qubits where the target state bit is '0'.""" + for i, bit in enumerate(reversed(target_state)): + if bit == "0": + qc.x(i) + + +def apply_phase_inversion(qc: QuantumCircuit, n: int) -> None: + """Apply a multi-controlled phase inversion (Z) to the marked state.""" + if n == 1: + qc.z(0) + elif n == 2: + qc.cz(0, 1) + else: + qc.h(n - 1) + qc.mcx(list(range(n - 1)), n - 1) # multi-controlled X (Toffoli for 3 qubits) + qc.h(n - 1) diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index e0859e7..eb4ca52 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -7,65 +7,44 @@ using matplotlib. """ from itertools import product +from math import floor, pi, sqrt +from typing import Iterator import matplotlib.pyplot as plt -import numpy as np from matplotlib.axes import Axes from qiskit import QuantumCircuit from qiskit_aer import AerSimulator +from grovers_visualizer.gates import apply_phase_inversion, encode_target_state from grovers_visualizer.state import QubitState -def x(qc: QuantumCircuit, target_state: QubitState) -> None: - for i, bit in enumerate(reversed(target_state)): - if bit == "0": - qc.x(i) - - -def ccz(qc: QuantumCircuit, n: int) -> None: - """Multi-controlled Z (for 3 qubits, this is a CCZ)""" - if n == 1: - qc.z(0) - elif n == 2: - qc.cz(0, 1) - else: - qc.h(n - 1) - qc.mcx(list(range(n - 1)), n - 1) # multi-controlled X (Toffoli for 3 qubits) - qc.h(n - 1) - - def oracle(qc: QuantumCircuit, target_state: QubitState) -> None: + """Oracle that flips the sign of the target state.""" n = len(target_state) - - x(qc, target_state) - - ccz(qc, n) - - # Undo the X gates - x(qc, target_state) + encode_target_state(qc, target_state) + apply_phase_inversion(qc, n) + encode_target_state(qc, target_state) # Undo def diffusion(qc: QuantumCircuit, n: int) -> None: """Apply the Grovers diffusion operator""" - qc.h(range(n)) qc.x(range(n)) - - ccz(qc, n) - + apply_phase_inversion(qc, n) qc.x(range(n)) qc.h(range(n)) def grover_search(n: int, target_state: QubitState) -> QuantumCircuit: + """Construct a Grover search circuit for the given target state.""" qc = QuantumCircuit(n, n) qc.h(range(n)) num_states = 2**n - iterations = int(np.floor(np.pi / 4 * np.sqrt(num_states))) + iterations = floor(pi / 4 * sqrt(num_states)) for _ in range(iterations): oracle(qc, target_state) diffusion(qc, n) @@ -74,8 +53,8 @@ def grover_search(n: int, target_state: QubitState) -> QuantumCircuit: return qc -def plot_counts(ax: Axes, counts: dict[str, int], target_state: str) -> None: - """Create and display a bar chart for the measurement results.""" +def plot_counts(ax: Axes, counts: dict[str, int], target_state: QubitState) -> None: + """Display a bar chart for the measurement results.""" # Sort the states states = list(counts.keys()) @@ -89,17 +68,21 @@ def plot_counts(ax: Axes, counts: dict[str, int], target_state: str) -> None: ax.set_ylim(0, max(frequencies) * 1.2) +def all_states(n_qubits: int) -> Iterator[QubitState]: + """Generate all possible QubitStates for n_qubits.""" + for bits in product("01", repeat=n_qubits): + yield QubitState("".join(bits)) + + def main() -> None: n_qubits = 3 - combinations = product(["0", "1"], repeat=n_qubits) - states = ["".join(x) for x in combinations] shots = 1024 _, ax = plt.subplots(figsize=(8, 4)) plt.ion() - for state in states: - qc = grover_search(n_qubits, QubitState(state)) + for state in all_states(n_qubits): + qc = grover_search(n_qubits, state) print(qc.draw("text"))