From 51a5f91410d4f59600a89731e8fbb40ae54225ba Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Wed, 23 Apr 2025 09:08:37 +0300 Subject: [PATCH] refactor(ket): use `tuple[int]` instead of `str` --- src/grovers_visualizer/gates.py | 2 +- src/grovers_visualizer/main.py | 9 +++--- src/grovers_visualizer/parse.py | 2 +- src/grovers_visualizer/state.py | 56 +++++++++++++++++++++------------ src/grovers_visualizer/utils.py | 4 +-- 5 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/grovers_visualizer/gates.py b/src/grovers_visualizer/gates.py index 4b78f1b..2a3c056 100644 --- a/src/grovers_visualizer/gates.py +++ b/src/grovers_visualizer/gates.py @@ -6,7 +6,7 @@ from .state import QubitState def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None: """Apply X gates to qubits where the target state bit is '0'.""" for i, bit in enumerate(reversed(target_state)): - if bit == "0": + if bit == 0: qc.x(i) diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index 433644d..a7da90a 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -10,6 +10,7 @@ from math import asin, sqrt from typing import TYPE_CHECKING, Callable import matplotlib.pyplot as plt +from matplotlib.backend_bases import KeyEvent from qiskit import QuantumCircuit from qiskit.quantum_info import Statevector @@ -34,10 +35,8 @@ def main() -> None: state_angle = 0.5 * theta plt.ion() - fig: Figure - ax_bar: Axes - ax_circle: Axes - fig, (ax_bar, ax_circle) = plt.subplots(1, 2, width_ratios=(3, 1), figsize=(12, 4)) + subplt: tuple[Figure, tuple[Axes, Axes]] = plt.subplots(1, 2, width_ratios=(3, 1), figsize=(12, 4)) + fig, (ax_bar, ax_circle) = subplt bars = ax_bar.bar(basis_states, [0] * len(basis_states), color="skyblue") ax_bar.set_ylim(-1, 1) ax_bar.set_title("Amplitudes (example)") @@ -63,7 +62,7 @@ def main() -> None: iteration = 1 running = True - def on_key(event) -> None: + def on_key(event: KeyEvent) -> None: nonlocal running if event.key == "q": running = False diff --git a/src/grovers_visualizer/parse.py b/src/grovers_visualizer/parse.py index 779a46d..c8c2d88 100644 --- a/src/grovers_visualizer/parse.py +++ b/src/grovers_visualizer/parse.py @@ -34,7 +34,7 @@ def parse_args() -> Args: ) ns = parser.parse_args() return Args( - target=QubitState(ns.target), + target=QubitState.from_str(ns.target), iterations=ns.iterations, speed=ns.speed, ) diff --git a/src/grovers_visualizer/state.py b/src/grovers_visualizer/state.py index 9909e07..5f71729 100644 --- a/src/grovers_visualizer/state.py +++ b/src/grovers_visualizer/state.py @@ -1,54 +1,70 @@ -from collections.abc import Iterator -from typing import Final, Self, override +from collections.abc import Iterable, Iterator +from typing import Self, override class QubitState: - def __init__(self, bits: str) -> None: - if not all(b in "01" for b in bits): - raise ValueError(f"{self.__class__.__name__} must be a string of '0' and '1'") - self._bits: Final[str] = bits + def __init__(self, bits: Iterable[int]) -> None: + bits_tuple = tuple(bits) # Convert to not consume it + if not all(b in (0, 1) for b in bits_tuple): + raise ValueError(f"{self.__class__.__name__} must be a tuple of `0`s and `1`s") + self._bits: tuple[int, ...] = tuple(bits_tuple) @property - def bits(self) -> str: + def bits(self) -> tuple[int, ...]: return self._bits + @property + def bitsring(self) -> str: + return "".join(str(b) for b in self._bits) + + @classmethod + def from_str(cls, s: str) -> Self: + return cls(int(b) for b in s) + @classmethod def from_int(cls, value: int, num_qubits: int) -> Self: - bits = format(value, f"0{num_qubits}b") + bits = (int(x) for x in format(value, f"0{num_qubits}b")) return cls(bits) @override def __str__(self) -> str: - return self._bits + return self.bitsring @override def __repr__(self) -> str: - return f"{self.__class__.__name__}('{self.bits}')" + return f"{self.__class__.__name__}('{self.bitsring}')" @override def __eq__(self, value: object, /) -> bool: if isinstance(value, QubitState): - return self.bits == value.bits + return self.bitsring == value.bitsring if isinstance(value, str): - return self.bits == value + return self.bitsring == value + if isinstance(value, (list, tuple)): + return self.bits == tuple(value) return NotImplemented def __lt__(self, value: object, /) -> bool: if isinstance(value, QubitState): - return int(self.bits, 2) < int(value.bits, 2) + return int(self.bitsring, 2) < int(value.bitsring, 2) if isinstance(value, str) and all(b in "01" for b in value): - return int(self.bits, 2) < int(value, 2) + return int(self.bitsring, 2) < int(value, 2) + if isinstance(value, (list, tuple)): + return self.bits < tuple(value) return NotImplemented @override def __hash__(self) -> int: - return hash(self.bits) + return hash(self.bitsring) def __len__(self) -> int: - return len(self.bits) + return len(self._bits) - def __getitem__(self, idx: int | slice) -> str: - return self.bits[idx] + def __getitem__(self, idx: int | slice) -> int | tuple[int, ...]: + return self._bits[idx] - def __iter__(self) -> Iterator[str]: - return iter(self.bits) + def __iter__(self) -> Iterator[int]: + return iter(self._bits) + + +Ket = QubitState diff --git a/src/grovers_visualizer/utils.py b/src/grovers_visualizer/utils.py index 5ce0d5c..419c7c6 100644 --- a/src/grovers_visualizer/utils.py +++ b/src/grovers_visualizer/utils.py @@ -7,8 +7,8 @@ from .state import QubitState def all_states(n_qubits: int) -> Iterator[QubitState]: """Generate all possible QubitStates for n_qubits.""" - for bits in product("01", repeat=n_qubits): - yield QubitState("".join(bits)) + for bits in product((0, 1), repeat=n_qubits): + yield QubitState(bits) def optimal_grover_iterations(n_qubits: int) -> int: