feat: implement grovers search

This commit is contained in:
Kristofers Solo 2025-04-17 15:36:54 +03:00
parent edddeec89e
commit f86fab1ae1

62
main.py
View File

@ -9,28 +9,73 @@ using matplotlib.
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from qiskit import QuantumCircuit
from qiskit_aer import AerSimulator
def grover_search(n: int) -> QuantumCircuit:
def x(qc: QuantumCircuit, target_state: str) -> None:
for i, bit in enumerate(reversed(target_state)):
if bit == "0":
qc.x(i)
def ccz(qc: QuantumCircuit, n: int) -> None:
"""Multi-controlled Z (for 3 qubits, this is a CCZ)"""
if n == 1:
qc.z(0)
elif n == 2:
qc.cz(0, 1)
else:
qc.h(n - 1)
qc.mcx(list(range(n - 1)), n - 1) # multi-controlled X (Toffoli for 3 qubits)
qc.h(n - 1)
def oracule(qc: QuantumCircuit, target_state: str) -> None:
n = len(target_state)
x(qc, target_state)
ccz(qc, n)
# Undo the X gates
x(qc, target_state)
def diffusion(qc: QuantumCircuit, n: int) -> None:
"""Apply the Grovers diffusion operator"""
qc.h(range(n))
qc.x(range(n))
ccz(qc, n)
qc.x(range(n))
qc.h(range(n))
def grover_search(n: int, target_state: str) -> QuantumCircuit:
qc = QuantumCircuit(n, n)
qc.h(range(n))
num_states = 2**n
iterations = int(np.floor(np.pi / 4 * np.sqrt(num_states)))
for _ in range(iterations):
oracule(qc, target_state)
diffusion(qc, n)
qc.measure(range(n), range(n))
return qc
def plot_counts(ax: Axes, counts: dict[str, int], target_state: str) -> None:
"""Create and display a bar chart for the measurement results.
"""Create and display a bar chart for the measurement results."""
Parameters:
counts - A dictionary mapping output states to counts.
target_state - The target state used in the Grover circuit.
"""
# Sort the states (optional: you can sort by state or by count)
# Sort the states
states = list(counts.keys())
frequencies = [counts[s] for s in states]
@ -52,7 +97,7 @@ def main() -> None:
plt.ion()
for state in states:
qc = grover_search(n_qubits)
qc = grover_search(n_qubits, state)
print(qc.draw("text"))
@ -68,7 +113,6 @@ def main() -> None:
plot_counts(ax, sorted_counts, state)
plt.pause(1)
# plt.ioff() # Do not close automatically
plt.show()