diff --git a/src/ai/heuristics/bumpiness.py b/src/ai/heuristics/bumpiness.py index 3632cc4..1fa73af 100644 --- a/src/ai/heuristics/bumpiness.py +++ b/src/ai/heuristics/bumpiness.py @@ -1,11 +1,13 @@ import numpy as np +from .peaks import get_peaks + def get_bumpiness( field: np.ndarray[int, np.dtype[np.uint8]], ) -> int: """ - Calculate the bumpiness of a given signal based on peaks. + Calculate the bumpiness of a given field based on peaks. Args: field: The game field. @@ -13,4 +15,6 @@ def get_bumpiness( Returns: The bumpiness of the field. """ - return int(np.sum(np.abs(np.diff(field.shape[0] - np.argmax(field, axis=0))))) + field = get_peaks(field) + diff = np.diff(field) + return int(np.sum(np.abs(diff))) diff --git a/src/ai/heuristics/peaks.py b/src/ai/heuristics/peaks.py index 6e52b90..c1f7a9c 100644 --- a/src/ai/heuristics/peaks.py +++ b/src/ai/heuristics/peaks.py @@ -11,7 +11,7 @@ def get_peaks(field: np.ndarray[int, np.dtype[np.uint8]]) -> np.ndarray[int, np. Returns: 2D array representing the peaks of the field. """ - result = np.zeros(field.shape[1], dtype=np.uint8) + result = np.zeros(field.shape[1], dtype=int) for col in range(field.shape[1]): for row in range(field.shape[0]): if field[row, col] != 0: