From 1ce17c43ba7204fd1be95044a8065b34341eeb3a Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 17:19:50 +0200 Subject: [PATCH] refactor(ai): add max and sum peaks refactor(ai): mean -> sum --- src/ai/fitness/peaks.py | 12 ++++++++++-- tests/ai/test_fitness.py | 8 ++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/ai/fitness/peaks.py b/src/ai/fitness/peaks.py index f67193e..f05684c 100644 --- a/src/ai/fitness/peaks.py +++ b/src/ai/fitness/peaks.py @@ -3,7 +3,7 @@ import numpy as np from ai.log import log -def get_peaks(field: np.ndarray) -> float: +def get_peaks(field: np.ndarray) -> np.ndarray: col_num = field.shape[1] peaks = np.zeros(col_num) @@ -11,4 +11,12 @@ def get_peaks(field: np.ndarray) -> float: if 1 in field[:, col]: peaks[col] = field.shape[0] - np.argmax(field[:, col], axis=0) - return float(np.sum(peaks)) + return peaks + + +def get_peaks_max(field: np.ndarray) -> int: + return int(np.max(get_peaks(field))) + + +def get_peaks_sum(field: np.ndarray) -> int: + return np.sum(get_peaks(field)) diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index 7241c23..bdfce50 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -1,7 +1,7 @@ import unittest import numpy as np -from ai.fitness.peaks import get_peaks +from ai.fitness.peaks import get_peaks_sum class TestFitness(unittest.TestCase): @@ -13,11 +13,11 @@ class TestFitness(unittest.TestCase): [0, 1, 1, 0, 0], ] ) - self.assertEqual(get_peaks(field), 11) + self.assertEqual(get_peaks_sum(field), 11) def test_get_peaks_zeros(self) -> None: field = np.zeros((3, 5)) - self.assertEqual(get_peaks(field), 0) + self.assertEqual(get_peaks_sum(field), 0) def test_single_peak(self): field = np.array( @@ -28,4 +28,4 @@ class TestFitness(unittest.TestCase): ] ) - self.assertEqual(get_peaks(field), 2) + self.assertEqual(get_peaks_sum(field), 2)