diff --git a/src/ai/fitness/bumpiness.py b/src/ai/fitness/bumpiness.py new file mode 100644 index 0000000..e735a0a --- /dev/null +++ b/src/ai/fitness/bumpiness.py @@ -0,0 +1,17 @@ +from typing import Optional + +import numpy as np + +from .peaks import get_peaks + + +def get_bumpiness( + peaks: Optional[np.ndarray], field: Optional[np.ndarray] = None +) -> int: + if peaks is None and field is None: + raise ValueError("peaks and field cannot both be None") + elif peaks is None: + peaks = get_peaks(field) + + differences = np.abs(np.diff(peaks)) + return int(np.sum(differences)) diff --git a/src/ai/fitness/fitness.py b/src/ai/fitness/fitness.py index 23d9a15..d45a5f4 100644 --- a/src/ai/fitness/fitness.py +++ b/src/ai/fitness/fitness.py @@ -13,10 +13,6 @@ def calculate_fitness(game: Game) -> float: return fitness -def get_bumpiness(field: np.ndarray) -> int: - pass - - def get_holes(field: np.ndarray) -> int: pass diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index 98f90a2..9c53c1b 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -1,10 +1,10 @@ import unittest import numpy as np +from ai.fitness.bumpiness import get_bumpiness from ai.fitness.peaks import get_peaks_sum from ai.fitness.transitions import ( get_col_transition, - get_col_transitions2, get_row_transition, ) @@ -39,7 +39,12 @@ class TestFitness(unittest.TestCase): for field, answer in zip(self.fields, answers): self.assertEqual(get_row_transition(field), answer) - def test_get_col_transistions2(self): + def test_get_col_transistions(self): answers = (5, 0, 1) for field, answer in zip(self.fields, answers): self.assertEqual(get_col_transition(field), answer) + + def test_get_bumpiness(self): + answers = (8, 0, 2) + for field, answer in zip(self.fields, answers): + self.assertEqual(get_bumpiness(field), answer)