diff --git a/src/ai/heuristics/height.py b/src/ai/heuristics/height.py index c511e74..6d9433c 100644 --- a/src/ai/heuristics/height.py +++ b/src/ai/heuristics/height.py @@ -1,5 +1,7 @@ import numpy as np +from .peaks import get_peaks + def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int: """ @@ -11,10 +13,5 @@ def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int: Returns: The aggregate height of the field. """ - heights = np.zeros(field.shape[1], dtype=np.uint8) - for col in range(field.shape[1]): - for row in range(field.shape[0]): - if field[row, col] != 0: - heights[col] = field.shape[0] - row - break + heights = get_peaks(field) return int(np.sum(heights)) diff --git a/src/ai/heuristics/peaks.py b/src/ai/heuristics/peaks.py new file mode 100644 index 0000000..6e52b90 --- /dev/null +++ b/src/ai/heuristics/peaks.py @@ -0,0 +1,20 @@ +import numpy as np + + +def get_peaks(field: np.ndarray[int, np.dtype[np.uint8]]) -> np.ndarray[int, np.dtype[np.uint8]]: + """ + Calculate the peaks of a given field. + + Args: + field: 2D array representing the game field. + + Returns: + 2D array representing the peaks of the field. + """ + result = np.zeros(field.shape[1], dtype=np.uint8) + for col in range(field.shape[1]): + for row in range(field.shape[0]): + if field[row, col] != 0: + result[col] = field.shape[0] - row + break + return result