refactor: update function names

This commit is contained in:
Kristofers Solo 2025-04-19 16:24:02 +03:00
parent 750a522069
commit 17ed8bfd53
2 changed files with 42 additions and 37 deletions

View File

@ -0,0 +1,22 @@
from qiskit import QuantumCircuit
from grovers_visualizer.state import QubitState
def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None:
"""Apply X gates to qubits where the target state bit is '0'."""
for i, bit in enumerate(reversed(target_state)):
if bit == "0":
qc.x(i)
def apply_phase_inversion(qc: QuantumCircuit, n: int) -> None:
"""Apply a multi-controlled phase inversion (Z) to the marked state."""
if n == 1:
qc.z(0)
elif n == 2:
qc.cz(0, 1)
else:
qc.h(n - 1)
qc.mcx(list(range(n - 1)), n - 1) # multi-controlled X (Toffoli for 3 qubits)
qc.h(n - 1)

View File

@ -7,65 +7,44 @@ using matplotlib.
"""
from itertools import product
from math import floor, pi, sqrt
from typing import Iterator
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
from grovers_visualizer.gates import apply_phase_inversion, encode_target_state
from grovers_visualizer.state import QubitState
def x(qc: QuantumCircuit, target_state: QubitState) -> None:
for i, bit in enumerate(reversed(target_state)):
if bit == "0":
qc.x(i)
def ccz(qc: QuantumCircuit, n: int) -> None:
"""Multi-controlled Z (for 3 qubits, this is a CCZ)"""
if n == 1:
qc.z(0)
elif n == 2:
qc.cz(0, 1)
else:
qc.h(n - 1)
qc.mcx(list(range(n - 1)), n - 1) # multi-controlled X (Toffoli for 3 qubits)
qc.h(n - 1)
def oracle(qc: QuantumCircuit, target_state: QubitState) -> None:
"""Oracle that flips the sign of the target state."""
n = len(target_state)
x(qc, target_state)
ccz(qc, n)
# Undo the X gates
x(qc, target_state)
encode_target_state(qc, target_state)
apply_phase_inversion(qc, n)
encode_target_state(qc, target_state) # Undo
def diffusion(qc: QuantumCircuit, n: int) -> None:
"""Apply the Grovers diffusion operator"""
qc.h(range(n))
qc.x(range(n))
ccz(qc, n)
apply_phase_inversion(qc, n)
qc.x(range(n))
qc.h(range(n))
def grover_search(n: int, target_state: QubitState) -> QuantumCircuit:
"""Construct a Grover search circuit for the given target state."""
qc = QuantumCircuit(n, n)
qc.h(range(n))
num_states = 2**n
iterations = int(np.floor(np.pi / 4 * np.sqrt(num_states)))
iterations = floor(pi / 4 * sqrt(num_states))
for _ in range(iterations):
oracle(qc, target_state)
diffusion(qc, n)
@ -74,8 +53,8 @@ def grover_search(n: int, target_state: QubitState) -> QuantumCircuit:
return qc
def plot_counts(ax: Axes, counts: dict[str, int], target_state: str) -> None:
"""Create and display a bar chart for the measurement results."""
def plot_counts(ax: Axes, counts: dict[str, int], target_state: QubitState) -> None:
"""Display a bar chart for the measurement results."""
# Sort the states
states = list(counts.keys())
@ -89,17 +68,21 @@ def plot_counts(ax: Axes, counts: dict[str, int], target_state: str) -> None:
ax.set_ylim(0, max(frequencies) * 1.2)
def all_states(n_qubits: int) -> Iterator[QubitState]:
"""Generate all possible QubitStates for n_qubits."""
for bits in product("01", repeat=n_qubits):
yield QubitState("".join(bits))
def main() -> None:
n_qubits = 3
combinations = product(["0", "1"], repeat=n_qubits)
states = ["".join(x) for x in combinations]
shots = 1024
_, ax = plt.subplots(figsize=(8, 4))
plt.ion()
for state in states:
qc = grover_search(n_qubits, QubitState(state))
for state in all_states(n_qubits):
qc = grover_search(n_qubits, state)
print(qc.draw("text"))