refactor(ket): use tuple[int] instead of str

This commit is contained in:
Kristofers Solo 2025-04-23 09:08:37 +03:00
parent 6e2fc0c2d5
commit 51a5f91410
5 changed files with 44 additions and 29 deletions

View File

@ -6,7 +6,7 @@ from .state import QubitState
def encode_target_state(qc: QuantumCircuit, target_state: QubitState) -> None:
"""Apply X gates to qubits where the target state bit is '0'."""
for i, bit in enumerate(reversed(target_state)):
if bit == "0":
if bit == 0:
qc.x(i)

View File

@ -10,6 +10,7 @@ from math import asin, sqrt
from typing import TYPE_CHECKING, Callable
import matplotlib.pyplot as plt
from matplotlib.backend_bases import KeyEvent
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
@ -34,10 +35,8 @@ def main() -> None:
state_angle = 0.5 * theta
plt.ion()
fig: Figure
ax_bar: Axes
ax_circle: Axes
fig, (ax_bar, ax_circle) = plt.subplots(1, 2, width_ratios=(3, 1), figsize=(12, 4))
subplt: tuple[Figure, tuple[Axes, Axes]] = plt.subplots(1, 2, width_ratios=(3, 1), figsize=(12, 4))
fig, (ax_bar, ax_circle) = subplt
bars = ax_bar.bar(basis_states, [0] * len(basis_states), color="skyblue")
ax_bar.set_ylim(-1, 1)
ax_bar.set_title("Amplitudes (example)")
@ -63,7 +62,7 @@ def main() -> None:
iteration = 1
running = True
def on_key(event) -> None:
def on_key(event: KeyEvent) -> None:
nonlocal running
if event.key == "q":
running = False

View File

@ -34,7 +34,7 @@ def parse_args() -> Args:
)
ns = parser.parse_args()
return Args(
target=QubitState(ns.target),
target=QubitState.from_str(ns.target),
iterations=ns.iterations,
speed=ns.speed,
)

View File

@ -1,54 +1,70 @@
from collections.abc import Iterator
from typing import Final, Self, override
from collections.abc import Iterable, Iterator
from typing import Self, override
class QubitState:
def __init__(self, bits: str) -> None:
if not all(b in "01" for b in bits):
raise ValueError(f"{self.__class__.__name__} must be a string of '0' and '1'")
self._bits: Final[str] = bits
def __init__(self, bits: Iterable[int]) -> None:
bits_tuple = tuple(bits) # Convert to not consume it
if not all(b in (0, 1) for b in bits_tuple):
raise ValueError(f"{self.__class__.__name__} must be a tuple of `0`s and `1`s")
self._bits: tuple[int, ...] = tuple(bits_tuple)
@property
def bits(self) -> str:
def bits(self) -> tuple[int, ...]:
return self._bits
@property
def bitsring(self) -> str:
return "".join(str(b) for b in self._bits)
@classmethod
def from_str(cls, s: str) -> Self:
return cls(int(b) for b in s)
@classmethod
def from_int(cls, value: int, num_qubits: int) -> Self:
bits = format(value, f"0{num_qubits}b")
bits = (int(x) for x in format(value, f"0{num_qubits}b"))
return cls(bits)
@override
def __str__(self) -> str:
return self._bits
return self.bitsring
@override
def __repr__(self) -> str:
return f"{self.__class__.__name__}('{self.bits}')"
return f"{self.__class__.__name__}('{self.bitsring}')"
@override
def __eq__(self, value: object, /) -> bool:
if isinstance(value, QubitState):
return self.bits == value.bits
return self.bitsring == value.bitsring
if isinstance(value, str):
return self.bits == value
return self.bitsring == value
if isinstance(value, (list, tuple)):
return self.bits == tuple(value)
return NotImplemented
def __lt__(self, value: object, /) -> bool:
if isinstance(value, QubitState):
return int(self.bits, 2) < int(value.bits, 2)
return int(self.bitsring, 2) < int(value.bitsring, 2)
if isinstance(value, str) and all(b in "01" for b in value):
return int(self.bits, 2) < int(value, 2)
return int(self.bitsring, 2) < int(value, 2)
if isinstance(value, (list, tuple)):
return self.bits < tuple(value)
return NotImplemented
@override
def __hash__(self) -> int:
return hash(self.bits)
return hash(self.bitsring)
def __len__(self) -> int:
return len(self.bits)
return len(self._bits)
def __getitem__(self, idx: int | slice) -> str:
return self.bits[idx]
def __getitem__(self, idx: int | slice) -> int | tuple[int, ...]:
return self._bits[idx]
def __iter__(self) -> Iterator[str]:
return iter(self.bits)
def __iter__(self) -> Iterator[int]:
return iter(self._bits)
Ket = QubitState

View File

@ -7,8 +7,8 @@ from .state import QubitState
def all_states(n_qubits: int) -> Iterator[QubitState]:
"""Generate all possible QubitStates for n_qubits."""
for bits in product("01", repeat=n_qubits):
yield QubitState("".join(bits))
for bits in product((0, 1), repeat=n_qubits):
yield QubitState(bits)
def optimal_grover_iterations(n_qubits: int) -> int: