From 0d4ab8aab762d2789d91ad73bf0badf3facf8466 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 18:04:17 +0200 Subject: [PATCH] refactor(ai): add `Optional` fix fix --- src/ai/fitness/fitness.py | 4 ---- src/ai/fitness/peaks.py | 22 ++++++++++++++++++---- src/ai/fitness/transitions.py | 13 +++++++++---- tests/ai/test_fitness.py | 6 +++--- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/ai/fitness/fitness.py b/src/ai/fitness/fitness.py index d45a5f4..2f769e9 100644 --- a/src/ai/fitness/fitness.py +++ b/src/ai/fitness/fitness.py @@ -13,10 +13,6 @@ def calculate_fitness(game: Game) -> float: return fitness -def get_holes(field: np.ndarray) -> int: - pass - - def get_wells(field: np.ndarray) -> int: pass diff --git a/src/ai/fitness/peaks.py b/src/ai/fitness/peaks.py index 218945c..ba29f2b 100644 --- a/src/ai/fitness/peaks.py +++ b/src/ai/fitness/peaks.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from ai.log import log @@ -8,9 +10,21 @@ def get_peaks(field: np.ndarray) -> np.ndarray: return peaks.max(axis=0) -def get_peaks_max(field: np.ndarray) -> int: - return int(np.max(get_peaks(field))) +def get_peaks_max( + peaks: Optional[np.ndarray], field: Optional[np.ndarray] = None +) -> int: + if peaks is None and field is None: + raise ValueError("peaks and field cannot both be None") + elif peaks is None: + peaks = get_peaks(field) + return int(np.max(peaks)) -def get_peaks_sum(field: np.ndarray) -> int: - return np.sum(get_peaks(field)) +def get_peaks_sum( + peaks: Optional[np.ndarray], field: Optional[np.ndarray] = None +) -> int: + if peaks is None and field is None: + raise ValueError("peaks and field cannot both be None") + elif peaks is None: + peaks = get_peaks(field) + return np.sum(peaks) diff --git a/src/ai/fitness/transitions.py b/src/ai/fitness/transitions.py index 89cb0db..c8b3258 100644 --- a/src/ai/fitness/transitions.py +++ b/src/ai/fitness/transitions.py @@ -1,10 +1,13 @@ +from typing import Optional + import numpy as np from .peaks import get_peaks, get_peaks_max -def get_row_transition(field: np.ndarray) -> int: - highest_peak = get_peaks_max(field) +def get_row_transition(field: np.ndarray, highest_peak: Optional[int] = None) -> int: + if highest_peak is None: + highest_peak = get_peaks_max(None, field) rows_to_check = slice(int(field.shape[0] - highest_peak), field.shape[0]) transitions = np.sum(field[rows_to_check, 1:] != field[rows_to_check, :-1]) @@ -12,9 +15,11 @@ def get_row_transition(field: np.ndarray) -> int: return int(transitions) -def get_col_transition(field: np.ndarray) -> int: +def get_col_transition(field: np.ndarray, peaks: Optional[np.ndarray] = None) -> int: + if peaks is None: + peaks = get_peaks(field) + transitions_sum = 0 - peaks = get_peaks(field) for col in range(field.shape[1]): if peaks[col] <= 1: diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index 9c53c1b..4f1f8f4 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -32,7 +32,7 @@ class TestFitness(unittest.TestCase): def test_get_peaks_sum(self) -> None: answers: tuple[int] = (11, 0, 2) for field, answer in zip(self.fields, answers): - self.assertEqual(get_peaks_sum(field), answer) + self.assertEqual(get_peaks_sum(None, field), answer) def test_get_row_transistions(self): answers = (8, 0, 2) @@ -45,6 +45,6 @@ class TestFitness(unittest.TestCase): self.assertEqual(get_col_transition(field), answer) def test_get_bumpiness(self): - answers = (8, 0, 2) + answers = (5, 0, 4) for field, answer in zip(self.fields, answers): - self.assertEqual(get_bumpiness(field), answer) + self.assertEqual(get_bumpiness(None, field), answer)