From 1b2158cb2517f5ea5a88a3310c12b77de6044723 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sat, 6 Jan 2024 18:01:07 +0200 Subject: [PATCH] feat(ai): add complete lines calculation --- src/ai/moves/height.py | 2 +- src/ai/moves/lines.py | 16 ++++++++++++++++ tests/ai/test_fitness.py | 8 ++++---- tests/ai/test_moves.py | 11 ++++++++--- 4 files changed, 29 insertions(+), 8 deletions(-) create mode 100644 src/ai/moves/lines.py diff --git a/src/ai/moves/height.py b/src/ai/moves/height.py index 5c8d7f8..cb331c4 100644 --- a/src/ai/moves/height.py +++ b/src/ai/moves/height.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import numpy as np diff --git a/src/ai/moves/lines.py b/src/ai/moves/lines.py new file mode 100644 index 0000000..324d7ad --- /dev/null +++ b/src/ai/moves/lines.py @@ -0,0 +1,16 @@ +from typing import Any + +import numpy as np + + +def complete_lines(field: np.ndarray[int, Any]) -> int: + """ + Calculates the number of complete lines in the field. + + Args: + field: 2D array representing the game field. + + Returns: + The number of complete lines in the field. + """ + return np.sum(np.all(field, axis=1)) diff --git a/tests/ai/test_fitness.py b/tests/ai/test_fitness.py index ec03d8f..91ad4eb 100644 --- a/tests/ai/test_fitness.py +++ b/tests/ai/test_fitness.py @@ -36,12 +36,12 @@ class TestFitness(unittest.TestCase): for field, answer in zip(self.fields, answers): self.assertEqual(get_peaks_sum(field=field), answer) - def test_get_row_transistions(self): + def test_get_row_transistions(self) -> None: answers = (8, 0, 2) for field, answer in zip(self.fields, answers): self.assertEqual(get_row_transition(field), answer) - def test_get_col_transistions(self): + def test_get_col_transistions(self) -> None: answers = (5, 0, 1) for field, answer in zip(self.fields, answers): self.assertEqual(get_col_transition(field), answer) @@ -51,7 +51,7 @@ class TestFitness(unittest.TestCase): for field, answer in zip(self.fields, answers): self.assertEqual(get_bumpiness(field=field), answer) - def test_get_holes(self): + def test_get_holes(self) -> None: answers = ( np.array([1, 1, 0, 1, 2]), np.array([0, 0, 0, 0, 0]), @@ -60,7 +60,7 @@ class TestFitness(unittest.TestCase): for field, answer in zip(self.fields, answers): self.assertTrue(np.array_equal(get_holes(field), answer)) - def test_get_wells(self): + def test_get_wells(self) -> None: answers = ( np.array([1, 0, 2, 1, 0]), np.array([0, 0, 0, 0, 0]), diff --git a/tests/ai/test_moves.py b/tests/ai/test_moves.py index c824b8c..e8194f1 100644 --- a/tests/ai/test_moves.py +++ b/tests/ai/test_moves.py @@ -2,11 +2,12 @@ import unittest import numpy as np from ai.moves.height import aggregate_height +from ai.moves.lines import complete_lines class TestFitness(unittest.TestCase): - def test_aggregate_height(self): - field = np.array( + def setUp(self) -> None: + self.field = np.array( [ [0, 0, 0, 0, 1, 1, 0, 0, 0, 0], [0, 1, 1, 1, 1, 1, 1, 0, 0, 1], @@ -17,4 +18,8 @@ class TestFitness(unittest.TestCase): ] ) - self.assertEqual(aggregate_height(field), 48) + def test_aggregate_height(self) -> None: + self.assertEqual(aggregate_height(self.field), 48) + + def test_complete_lines(self) -> None: + self.assertEqual(complete_lines(self.field), 2)