From e84cacca1cc77b7d41b39247ff145cc27bd81e7a Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 12 Jan 2024 16:33:38 +0200 Subject: [PATCH] feat(ai): add `get_peaks` fix(ai): `aggregate_height` --- src/ai/heuristics/height.py | 9 +++------ src/ai/heuristics/peaks.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 src/ai/heuristics/peaks.py diff --git a/src/ai/heuristics/height.py b/src/ai/heuristics/height.py index c511e74..6d9433c 100644 --- a/src/ai/heuristics/height.py +++ b/src/ai/heuristics/height.py @@ -1,5 +1,7 @@ import numpy as np +from .peaks import get_peaks + def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int: """ @@ -11,10 +13,5 @@ def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int: Returns: The aggregate height of the field. """ - heights = np.zeros(field.shape[1], dtype=np.uint8) - for col in range(field.shape[1]): - for row in range(field.shape[0]): - if field[row, col] != 0: - heights[col] = field.shape[0] - row - break + heights = get_peaks(field) return int(np.sum(heights)) diff --git a/src/ai/heuristics/peaks.py b/src/ai/heuristics/peaks.py new file mode 100644 index 0000000..6e52b90 --- /dev/null +++ b/src/ai/heuristics/peaks.py @@ -0,0 +1,20 @@ +import numpy as np + + +def get_peaks(field: np.ndarray[int, np.dtype[np.uint8]]) -> np.ndarray[int, np.dtype[np.uint8]]: + """ + Calculate the peaks of a given field. + + Args: + field: 2D array representing the game field. + + Returns: + 2D array representing the peaks of the field. + """ + result = np.zeros(field.shape[1], dtype=np.uint8) + for col in range(field.shape[1]): + for row in range(field.shape[0]): + if field[row, col] != 0: + result[col] = field.shape[0] - row + break + return result