From dec7c4d0e0b116199e16fadc0c824b16cce88e5b Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 17:43:14 +0200 Subject: [PATCH] feat(ai): add `get_row_transition` --- src/ai/fitness/fitness.py | 8 ------- src/ai/fitness/transitions.py | 16 +++++++++++++ tests/ai/test_fitness.py | 44 +++++++++++++++++------------------ 3 files changed, 38 insertions(+), 30 deletions(-) create mode 100644 src/ai/fitness/transitions.py diff --git a/src/ai/fitness/fitness.py b/src/ai/fitness/fitness.py index dd8b6ef..23d9a15 100644 --- a/src/ai/fitness/fitness.py +++ b/src/ai/fitness/fitness.py @@ -13,14 +13,6 @@ def calculate_fitness(game: Game) -> float: return fitness -def get_row_transitions(field: np.ndarray) -> int: - pass - - -def get_col_transitions(field: np.ndarray) -> int: - pass - - def get_bumpiness(field: np.ndarray) -> int: pass diff --git a/src/ai/fitness/transitions.py b/src/ai/fitness/transitions.py new file mode 100644 index 0000000..ce2af02 --- /dev/null +++ b/src/ai/fitness/transitions.py @@ -0,0 +1,16 @@ +import numpy as np + +from .peaks import get_peaks_max + + +def get_row_transitions(field: np.ndarray) -> int: + highest_peak = get_peaks_max(field) + + rows_to_check = slice(int(field.shape[0] - highest_peak), field.shape[0]) + transitions = np.sum(field[rows_to_check, 1:] != field[rows_to_check, :-1]) + + return int(transitions) + + +def get_col_transitions(field: np.ndarray) -> int: + pass diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index 984bfb4..e502b97 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -6,30 +6,30 @@ from ai.fitness.transitions import get_row_transitions class TestFitness(unittest.TestCase): - def test_get_peaks(self) -> None: - field = np.array( - [ - [0, 1, 0, 0, 1], - [1, 0, 0, 1, 0], - [0, 1, 1, 0, 0], - ] - ) - self.assertEqual(get_peaks_sum(field), 11) - - def test_get_peaks_zeros(self) -> None: - field = np.zeros((3, 5)) - self.assertEqual(get_peaks_sum(field), 0) - - def test_single_peak(self): - field = np.array( - [ - [0, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 0, 0, 0], - ] + def setUp(self) -> None: + self.fields: tuple[np.ndarray] = ( + np.array( + [ + [0, 1, 0, 0, 1], + [1, 0, 0, 1, 0], + [0, 1, 1, 0, 0], + ] + ), + np.zeros((3, 5)), + np.array( + [ + [0, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ), ) - self.assertEqual(get_peaks_sum(field), 2) + def test_get_peaks_sum(self) -> None: + answers: tuple[int] = (11, 0, 2) + for field, answer in zip(self.fields, answers): + self.assertEqual(get_peaks_sum(field), answer) + def test_get_row_transistions(self): answers = (8, 0, 2) for field, answer in zip(self.fields, answers):