From f673a9f85087866a0236c56719fedcc5ba703f16 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 17:50:36 +0200 Subject: [PATCH] feat(ai): add `get_col_transition` --- src/ai/fitness/transitions.py | 20 ++++++++++++++++---- tests/ai/test_fitness.py | 13 +++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/ai/fitness/transitions.py b/src/ai/fitness/transitions.py index ce2af02..89cb0db 100644 --- a/src/ai/fitness/transitions.py +++ b/src/ai/fitness/transitions.py @@ -1,9 +1,9 @@ import numpy as np -from .peaks import get_peaks_max +from .peaks import get_peaks, get_peaks_max -def get_row_transitions(field: np.ndarray) -> int: +def get_row_transition(field: np.ndarray) -> int: highest_peak = get_peaks_max(field) rows_to_check = slice(int(field.shape[0] - highest_peak), field.shape[0]) @@ -12,5 +12,17 @@ def get_row_transitions(field: np.ndarray) -> int: return int(transitions) -def get_col_transitions(field: np.ndarray) -> int: - pass +def get_col_transition(field: np.ndarray) -> int: + transitions_sum = 0 + peaks = get_peaks(field) + + for col in range(field.shape[1]): + if peaks[col] <= 1: + continue + + col_values = field[int(field.shape[0] - peaks[col]) : field.shape[0], col] + transitions = np.sum(col_values[:-1] != col_values[1:]) + + transitions_sum += transitions + + return transitions_sum diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index e502b97..98f90a2 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -2,7 +2,11 @@ import unittest import numpy as np from ai.fitness.peaks import get_peaks_sum -from ai.fitness.transitions import get_row_transitions +from ai.fitness.transitions import ( + get_col_transition, + get_col_transitions2, + get_row_transition, +) class TestFitness(unittest.TestCase): @@ -33,4 +37,9 @@ class TestFitness(unittest.TestCase): def test_get_row_transistions(self): answers = (8, 0, 2) for field, answer in zip(self.fields, answers): - self.assertEqual(get_row_transitions(field), answer) + self.assertEqual(get_row_transition(field), answer) + + def test_get_col_transistions2(self): + answers = (5, 0, 1) + for field, answer in zip(self.fields, answers): + self.assertEqual(get_col_transition(field), answer)