From 41cef03f5040531ff2f9f3abff84202e8fbcb468 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 6 Jan 2024 17:56:02 +0200 Subject: [PATCH] feat(ai): add aggregate height calculation --- src/ai/moves/height.py | 16 ++++++++++++++++ tests/ai/test_moves.py | 20 ++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 src/ai/moves/height.py create mode 100644 tests/ai/test_moves.py diff --git a/src/ai/moves/height.py b/src/ai/moves/height.py new file mode 100644 index 0000000..5c8d7f8 --- /dev/null +++ b/src/ai/moves/height.py @@ -0,0 +1,16 @@ +from typing import Any, Optional + +import numpy as np + + +def aggregate_height(field: np.ndarray[int, Any]) -> int: + """ + Calculates the aggregate height of the field. + + Args: + field: 2D array representing the game field. + + Returns: + The aggregate height of the field. + """ + return np.sum(field.shape[0] - np.argmax(field, axis=0)) diff --git a/tests/ai/test_moves.py b/tests/ai/test_moves.py new file mode 100644 index 0000000..c824b8c --- /dev/null +++ b/tests/ai/test_moves.py @@ -0,0 +1,20 @@ +import unittest + +import numpy as np +from ai.moves.height import aggregate_height + + +class TestFitness(unittest.TestCase): + def test_aggregate_height(self): + field = np.array( + [ + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 1, 0, 0, 1], + [0, 1, 1, 0, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 0, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ] + ) + + self.assertEqual(aggregate_height(field), 48)