diff --git a/src/ai/fitness/peaks.py b/src/ai/fitness/peaks.py index f05684c..218945c 100644 --- a/src/ai/fitness/peaks.py +++ b/src/ai/fitness/peaks.py @@ -4,14 +4,8 @@ from ai.log import log def get_peaks(field: np.ndarray) -> np.ndarray: - col_num = field.shape[1] - peaks = np.zeros(col_num) - - for col in range(col_num): - if 1 in field[:, col]: - peaks[col] = field.shape[0] - np.argmax(field[:, col], axis=0) - - return peaks + peaks = np.where(field == 1, field.shape[0] - np.argmax(field, axis=0), 0) + return peaks.max(axis=0) def get_peaks_max(field: np.ndarray) -> int: