diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index bdfce50..984bfb4 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -2,6 +2,7 @@ import unittest import numpy as np from ai.fitness.peaks import get_peaks_sum +from ai.fitness.transitions import get_row_transitions class TestFitness(unittest.TestCase): @@ -29,3 +30,7 @@ class TestFitness(unittest.TestCase): ) self.assertEqual(get_peaks_sum(field), 2) + def test_get_row_transistions(self): + answers = (8, 0, 2) + for field, answer in zip(self.fields, answers): + self.assertEqual(get_row_transitions(field), answer)