feat(cli): add phase arg

This commit is contained in:
Kristofers Solo 2025-05-08 18:49:39 +03:00
parent 78952d9364
commit 30cc848ecd
Signed by: kristoferssolo
GPG Key ID: 74FF8144483D82C8
4 changed files with 39 additions and 17 deletions

View File

@ -1,3 +1,4 @@
import math
from argparse import ArgumentParser
from dataclasses import dataclass
@ -11,6 +12,7 @@ class Args:
iterations: int
speed: float
ui: bool
phase: float
def parse_args() -> Args:
@ -26,6 +28,7 @@ def parse_args() -> Args:
iterations=ns.iterations,
speed=ns.speed,
ui=ns.ui,
phase=ns.phase,
)
@ -61,4 +64,15 @@ def parse_cli(base_parser: ArgumentParser) -> None:
default=0.5,
help="Pause duration (seconds) between steps (deafult: 0.5)",
)
parser.add_argument(
"-p",
"--phase",
type=float,
default=math.pi,
help=(
"The phase φ (in radians) used for the oracle and diffusion steps. "
"Defaults to π, which implements the usual sign flip e^(iπ) = -1."
),
)
parser.add_argument("--ui", action="store_true", help="Run with DearPyGui UI")

View File

@ -1,35 +1,38 @@
import math
from qiskit import QuantumCircuit
from qiskit.circuit.library import PhaseGate
from .state import QubitState
def oracle(qc: QuantumCircuit, target_state: QubitState) -> None:
def oracle(qc: QuantumCircuit, target_state: QubitState, /, *, phase: float = math.pi) -> None:
"""Oracle that flips the sign of the target state."""
n = len(target_state)
encode_target_state(qc, target_state)
apply_phase_inversion(qc, n)
apply_phase_inversion(qc, n, phase=phase)
encode_target_state(qc, target_state) # Undo
def oracle_circuit(target: QubitState) -> QuantumCircuit:
def oracle_circuit(target: QubitState, /, *, phase: float = math.pi) -> QuantumCircuit:
n = len(target)
qc = QuantumCircuit(n)
oracle(qc, target)
oracle(qc, target, phase=phase)
return qc
def diffusion(qc: QuantumCircuit, n: int) -> None:
def diffusion(qc: QuantumCircuit, n: int, /, *, phase: float = math.pi) -> None:
"""Apply the Grovers diffusion operator."""
qc.h(range(n))
qc.x(range(n))
apply_phase_inversion(qc, n)
apply_phase_inversion(qc, n, phase=phase)
qc.x(range(n))
qc.h(range(n))
def diffusion_circuit(n: int) -> QuantumCircuit:
def diffusion_circuit(n: int, /, *, phase: float = math.pi) -> QuantumCircuit:
qc = QuantumCircuit(n)
diffusion(qc, n)
diffusion(qc, n, phase=phase)
return qc
@ -40,11 +43,10 @@ def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None:
qc.x(i)
def apply_phase_inversion(qc: QuantumCircuit, n: int) -> None:
def apply_phase_inversion(qc: QuantumCircuit, n: int, /, *, phase: float = math.pi) -> None:
"""Apply a multi-controlled phase inversion (Z) to the marked state."""
if n == 1:
qc.z(0)
qc.p(phase, 0)
return
qc.h(n - 1)
qc.mcx(list(range(n - 1)), n - 1)
qc.h(n - 1)
mc_phase = PhaseGate(phase).control(n - 1)
qc.append(mc_phase, list(range(n)))

View File

@ -7,7 +7,7 @@ from .args import Args
def run_cli(args: Args) -> None:
vis = GroverVisualizer(args.target, pause=args.speed)
for it, sv in grover_evolver(vis.target, args.iterations):
for it, sv in grover_evolver(vis.target, args.iterations, phase=args.phase):
if not vis.is_running:
break
vis.update(it, sv)

View File

@ -1,3 +1,4 @@
import math
from collections.abc import Iterator
from itertools import count
@ -9,7 +10,12 @@ from grovers_visualizer.circuit import diffusion_circuit, oracle_circuit
from grovers_visualizer.state import QubitState
def grover_evolver(target: QubitState, max_iterations: int = 0) -> Iterator[tuple[int, Statevector]]:
def grover_evolver(
target: QubitState,
max_iterations: int = 0,
*,
phase: float = math.pi,
) -> Iterator[tuple[int, Statevector]]:
"""Yields (iteration, statevector) pairs.
- iteration=0 is the uniform-Hadamard initialization
@ -24,8 +30,8 @@ def grover_evolver(target: QubitState, max_iterations: int = 0) -> Iterator[tupl
sv = Statevector.from_instruction(qc)
yield 0, sv
oracle_op = Operator(oracle_circuit(target))
diffusion_op = Operator(diffusion_circuit(n_qubits))
oracle_op = Operator(oracle_circuit(target, phase=phase))
diffusion_op = Operator(diffusion_circuit(n_qubits, phase=phase))
iters = range(1, max_iterations + 1) if max_iterations > 0 else count(1)
for i in iters: