From e9cd973360ba4e3557e8d61abb8501c838f3db52 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 6 Jan 2024 18:22:08 +0200 Subject: [PATCH] feat(ai): add bumpiness calculation --- src/ai/moves/bumpiness.py | 16 ++++++++++++++++ src/ai/moves/holes.py | 2 -- tests/ai/test_moves.py | 4 ++++ 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 src/ai/moves/bumpiness.py diff --git a/src/ai/moves/bumpiness.py b/src/ai/moves/bumpiness.py new file mode 100644 index 0000000..17c57ae --- /dev/null +++ b/src/ai/moves/bumpiness.py @@ -0,0 +1,16 @@ +import numpy as np + + +def bumpiness( + field: np.ndarray[int, np.dtype[np.uint8]], +) -> int: + """ + Calculate the bumpiness of a given signal based on peaks. + + Args: + field: The game field. + + Returns: + The bumpiness of the field. + """ + return int(np.sum(np.abs(np.diff(field.shape[0] - np.argmax(field, axis=0))))) diff --git a/src/ai/moves/holes.py b/src/ai/moves/holes.py index 1479ed5..77ee04c 100644 --- a/src/ai/moves/holes.py +++ b/src/ai/moves/holes.py @@ -1,5 +1,3 @@ -from typing import Optional - import numpy as np diff --git a/tests/ai/test_moves.py b/tests/ai/test_moves.py index de9c0a6..afc98fa 100644 --- a/tests/ai/test_moves.py +++ b/tests/ai/test_moves.py @@ -1,6 +1,7 @@ import unittest import numpy as np +from ai.moves.bumpiness import bumpiness from ai.moves.height import aggregate_height from ai.moves.holes import holes from ai.moves.lines import complete_lines @@ -27,3 +28,6 @@ class TestFitness(unittest.TestCase): def test_holes(self) -> None: self.assertEqual(holes(self.field), 2) + + def test_bumpiness(self) -> None: + self.assertEqual(bumpiness(self.field), 6)