diff --git a/src/ai/moves/height.py b/src/ai/moves/height.py index cb331c4..8f87b05 100644 --- a/src/ai/moves/height.py +++ b/src/ai/moves/height.py @@ -1,9 +1,7 @@ -from typing import Any - import numpy as np -def aggregate_height(field: np.ndarray[int, Any]) -> int: +def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int: """ Calculates the aggregate height of the field. @@ -13,4 +11,4 @@ def aggregate_height(field: np.ndarray[int, Any]) -> int: Returns: The aggregate height of the field. """ - return np.sum(field.shape[0] - np.argmax(field, axis=0)) + return int(np.sum(field.shape[0] - np.argmax(field, axis=0))) diff --git a/src/ai/moves/holes.py b/src/ai/moves/holes.py new file mode 100644 index 0000000..1479ed5 --- /dev/null +++ b/src/ai/moves/holes.py @@ -0,0 +1,26 @@ +from typing import Optional + +import numpy as np + + +def holes( + field: np.ndarray[int, np.dtype[np.uint8]], +) -> int: + """ + Calculate the number of holes in each column of the given field. + + Args: + field: The signal field. + peaks: Array containing peak indices. If not provided, it will be computed from the field. + + Returns: + The total number of holes in the field. + """ + + first_nonzero_indices = np.argmax(field != 0, axis=0) + + mask = (field == 0) & ( + np.arange(field.shape[0])[:, np.newaxis] > first_nonzero_indices + ) + + return int(np.sum(mask)) diff --git a/src/ai/moves/lines.py b/src/ai/moves/lines.py index 324d7ad..01d9061 100644 --- a/src/ai/moves/lines.py +++ b/src/ai/moves/lines.py @@ -1,9 +1,7 @@ -from typing import Any - import numpy as np -def complete_lines(field: np.ndarray[int, Any]) -> int: +def complete_lines(field: np.ndarray[int, np.dtype[np.uint8]]) -> int: """ Calculates the number of complete lines in the field. @@ -13,4 +11,4 @@ def complete_lines(field: np.ndarray[int, Any]) -> int: Returns: The number of complete lines in the field. """ - return np.sum(np.all(field, axis=1)) + return int(np.sum(np.all(field, axis=1))) diff --git a/tests/ai/test_moves.py b/tests/ai/test_moves.py index e8194f1..de9c0a6 100644 --- a/tests/ai/test_moves.py +++ b/tests/ai/test_moves.py @@ -2,6 +2,7 @@ import unittest import numpy as np from ai.moves.height import aggregate_height +from ai.moves.holes import holes from ai.moves.lines import complete_lines @@ -23,3 +24,6 @@ class TestFitness(unittest.TestCase): def test_complete_lines(self) -> None: self.assertEqual(complete_lines(self.field), 2) + + def test_holes(self) -> None: + self.assertEqual(holes(self.field), 2)