From 4a17ccf608f1c5047e87eece66575d530ab2e059 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 18 Apr 2025 16:11:24 +0300 Subject: [PATCH] feat(qubit): add custom qubit state type --- src/grovers_visualizer/main.py | 10 +++++--- src/grovers_visualizer/state.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/grovers_visualizer/main.py b/src/grovers_visualizer/main.py index 897c842..61e3a9d 100644 --- a/src/grovers_visualizer/main.py +++ b/src/grovers_visualizer/main.py @@ -14,8 +14,10 @@ from matplotlib.axes import Axes from qiskit import QuantumCircuit from qiskit_aer import AerSimulator +from grovers_visualizer.state import QubitState -def x(qc: QuantumCircuit, target_state: str) -> None: + +def x(qc: QuantumCircuit, target_state: QubitState) -> None: for i, bit in enumerate(reversed(target_state)): if bit == "0": qc.x(i) @@ -33,7 +35,7 @@ def ccz(qc: QuantumCircuit, n: int) -> None: qc.h(n - 1) -def oracule(qc: QuantumCircuit, target_state: str) -> None: +def oracule(qc: QuantumCircuit, target_state: QubitState) -> None: n = len(target_state) x(qc, target_state) @@ -56,7 +58,7 @@ def diffusion(qc: QuantumCircuit, n: int) -> None: qc.h(range(n)) -def grover_search(n: int, target_state: str) -> QuantumCircuit: +def grover_search(n: int, target_state: QubitState) -> QuantumCircuit: qc = QuantumCircuit(n, n) qc.h(range(n)) @@ -97,7 +99,7 @@ def main() -> None: plt.ion() for state in states: - qc = grover_search(n_qubits, state) + qc = grover_search(n_qubits, QubitState(state)) print(qc.draw("text")) diff --git a/src/grovers_visualizer/state.py b/src/grovers_visualizer/state.py index e69de29..1866963 100644 --- a/src/grovers_visualizer/state.py +++ b/src/grovers_visualizer/state.py @@ -0,0 +1,45 @@ +from collections.abc import Iterator +from typing import Final, Self, override + + +class QubitState: + def __init__(self, bits: str) -> None: + if not all(b in "01" for b in bits): + raise ValueError(f"{self.__class__.__name__} must be a string of '0' and '1'") + self._bits: Final[str] = bits + + @property + def bits(self) -> str: + return self._bits + + @classmethod + def from_int(cls, value: int, num_qubits: int) -> Self: + bits = format(value, f"0{num_qubits}b") + return cls(bits) + + @override + def __str__(self) -> str: + return self._bits + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}('{self.bits}')" + + @override + def __eq__(self, value: object, /) -> bool: + if isinstance(value, QubitState): + return self.bits == value.bits + return False + + @override + def __hash__(self) -> int: + return hash(self.bits) + + def __len__(self) -> int: + return len(self.bits) + + def __getitem__(self, idx: int | slice) -> str: + return self.bits[idx] + + def __iter__(self) -> Iterator[str]: + return iter(self.bits)