From d458a56d2746352ac8ae5ce08fb135092368bddc Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 13 Jan 2024 19:49:39 +0200 Subject: [PATCH] fix(ai): `get_bumpiness` --- src/ai/heuristics/bumpiness.py | 8 ++++++-- src/ai/heuristics/peaks.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ai/heuristics/bumpiness.py b/src/ai/heuristics/bumpiness.py index 3632cc4..1fa73af 100644 --- a/src/ai/heuristics/bumpiness.py +++ b/src/ai/heuristics/bumpiness.py @@ -1,11 +1,13 @@ import numpy as np +from .peaks import get_peaks + def get_bumpiness( field: np.ndarray[int, np.dtype[np.uint8]], ) -> int: """ - Calculate the bumpiness of a given signal based on peaks. + Calculate the bumpiness of a given field based on peaks. Args: field: The game field. @@ -13,4 +15,6 @@ def get_bumpiness( Returns: The bumpiness of the field. """ - return int(np.sum(np.abs(np.diff(field.shape[0] - np.argmax(field, axis=0))))) + field = get_peaks(field) + diff = np.diff(field) + return int(np.sum(np.abs(diff))) diff --git a/src/ai/heuristics/peaks.py b/src/ai/heuristics/peaks.py index 6e52b90..c1f7a9c 100644 --- a/src/ai/heuristics/peaks.py +++ b/src/ai/heuristics/peaks.py @@ -11,7 +11,7 @@ def get_peaks(field: np.ndarray[int, np.dtype[np.uint8]]) -> np.ndarray[int, np. Returns: 2D array representing the peaks of the field. """ - result = np.zeros(field.shape[1], dtype=np.uint8) + result = np.zeros(field.shape[1], dtype=int) for col in range(field.shape[1]): for row in range(field.shape[0]): if field[row, col] != 0: