mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
Merge branch 'feature/algorithm' of github.com:kristoferssolo/Tetris into feature/algorithm
This commit is contained in:
commit
3ac36d8a58
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install -r requirements.txt
|
pip install -e .
|
||||||
pip install -r requirements_dev.txt
|
pip install -r requirements_dev.txt
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@ -69,7 +69,7 @@ extend-select = [
|
|||||||
"TID",
|
"TID",
|
||||||
"YTT",
|
"YTT",
|
||||||
]
|
]
|
||||||
ignore = ["E741"]
|
ignore = ["E741", "E711"]
|
||||||
show-fixes = true
|
show-fixes = true
|
||||||
line-length = 120
|
line-length = 120
|
||||||
indent-width = 4
|
indent-width = 4
|
||||||
|
|||||||
3
src/ai/__init__.py
Normal file
3
src/ai/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .ai import run
|
||||||
|
|
||||||
|
__all__ = ["run"]
|
||||||
35
src/ai/ai.py
Normal file
35
src/ai/ai.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from game import Main
|
||||||
|
from loguru import logger
|
||||||
|
from utils import BestMove, GameMode
|
||||||
|
|
||||||
|
from .move import get_best_move
|
||||||
|
|
||||||
|
|
||||||
|
def run() -> None:
|
||||||
|
app = Main(GameMode.AI_TRAINING)
|
||||||
|
app.play()
|
||||||
|
game = app.game
|
||||||
|
|
||||||
|
if not game:
|
||||||
|
return
|
||||||
|
|
||||||
|
tetris = game.tetris
|
||||||
|
|
||||||
|
while True:
|
||||||
|
app.handle_events()
|
||||||
|
app.run_game_loop()
|
||||||
|
|
||||||
|
best_move: BestMove = get_best_move(game.tetris, tetris.tetromino.figure)
|
||||||
|
figure = game.tetris.tetromino.figure
|
||||||
|
logger.warning(f"{figure.name=} {best_move=}")
|
||||||
|
|
||||||
|
for rotation in range(best_move.rotation):
|
||||||
|
tetris.tetromino.rotate()
|
||||||
|
|
||||||
|
for _ in range(abs(best_move.x_axis_offset)):
|
||||||
|
if best_move.x_axis_offset > 0:
|
||||||
|
tetris.move_right()
|
||||||
|
elif best_move.x_axis_offset < 0:
|
||||||
|
tetris.move_left()
|
||||||
|
|
||||||
|
tetris.drop()
|
||||||
6
src/ai/heuristics/__init__.py
Normal file
6
src/ai/heuristics/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from .bumpiness import get_bumpiness
|
||||||
|
from .height import aggregate_height
|
||||||
|
from .holes import count_holes
|
||||||
|
from .lines import complete_lines
|
||||||
|
|
||||||
|
__all__ = ["aggregate_height", "get_bumpiness", "complete_lines", "count_holes"]
|
||||||
20
src/ai/heuristics/bumpiness.py
Normal file
20
src/ai/heuristics/bumpiness.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .peaks import get_peaks
|
||||||
|
|
||||||
|
|
||||||
|
def get_bumpiness(
|
||||||
|
field: np.ndarray[int, np.dtype[np.uint8]],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the bumpiness of a given field based on peaks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The game field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The bumpiness of the field.
|
||||||
|
"""
|
||||||
|
field = get_peaks(field)
|
||||||
|
diff = np.diff(field)
|
||||||
|
return int(np.sum(np.abs(diff)))
|
||||||
17
src/ai/heuristics/height.py
Normal file
17
src/ai/heuristics/height.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .peaks import get_peaks
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the aggregate height of the field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: 2D array representing the game field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The aggregate height of the field.
|
||||||
|
"""
|
||||||
|
heights = get_peaks(field)
|
||||||
|
return int(np.sum(heights))
|
||||||
29
src/ai/heuristics/holes.py
Normal file
29
src/ai/heuristics/holes.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def count_holes(
|
||||||
|
field: np.ndarray[int, np.dtype[np.uint8]],
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the number of holes in each column of the given field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: The signal field.
|
||||||
|
peaks: Array containing peak indices. If not provided, it will be computed from the field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The total number of holes in the field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_rows, num_cols = field.shape
|
||||||
|
holes_count = 0
|
||||||
|
|
||||||
|
for col in range(num_cols):
|
||||||
|
has_tile_above = False
|
||||||
|
|
||||||
|
for row in range(num_rows):
|
||||||
|
if field[row, col] == 1:
|
||||||
|
has_tile_above = True
|
||||||
|
elif field[row, col] == 0 and has_tile_above:
|
||||||
|
holes_count += 1
|
||||||
|
return holes_count
|
||||||
14
src/ai/heuristics/lines.py
Normal file
14
src/ai/heuristics/lines.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def complete_lines(field: np.ndarray[int, np.dtype[np.uint8]]) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the number of complete lines in the field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: 2D array representing the game field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of complete lines in the field.
|
||||||
|
"""
|
||||||
|
return int(np.sum(np.all(field, axis=1)))
|
||||||
20
src/ai/heuristics/peaks.py
Normal file
20
src/ai/heuristics/peaks.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_peaks(field: np.ndarray[int, np.dtype[np.uint8]]) -> np.ndarray[int, np.dtype[np.uint8]]:
|
||||||
|
"""
|
||||||
|
Calculate the peaks of a given field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: 2D array representing the game field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
2D array representing the peaks of the field.
|
||||||
|
"""
|
||||||
|
result = np.zeros(field.shape[1], dtype=int)
|
||||||
|
for col in range(field.shape[1]):
|
||||||
|
for row in range(field.shape[0]):
|
||||||
|
if field[row, col] != 0:
|
||||||
|
result[col] = field.shape[0] - row
|
||||||
|
break
|
||||||
|
return result
|
||||||
63
src/ai/move.py
Normal file
63
src/ai/move.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pygame
|
||||||
|
from game import Tetris
|
||||||
|
from game.sprites import Tetromino
|
||||||
|
from loguru import logger
|
||||||
|
from utils import CONFIG, BestMove, Direction, Figure
|
||||||
|
|
||||||
|
from .score import calculate_score
|
||||||
|
|
||||||
|
NUM_ROTATIONS: dict[Figure, int] = {
|
||||||
|
Figure.I: 2,
|
||||||
|
Figure.O: 1,
|
||||||
|
Figure.T: 4,
|
||||||
|
Figure.S: 2,
|
||||||
|
Figure.Z: 2,
|
||||||
|
Figure.J: 4,
|
||||||
|
Figure.L: 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_best_move(game: Tetris, figure: Figure) -> BestMove:
|
||||||
|
best_move: Optional[BestMove] = None
|
||||||
|
best_score: Optional[float] = None
|
||||||
|
phantom_sprites = pygame.sprite.Group() # type: ignore
|
||||||
|
|
||||||
|
for rotation in range(NUM_ROTATIONS[figure]):
|
||||||
|
for i in range(CONFIG.game.columns):
|
||||||
|
tetermino = Tetromino(phantom_sprites, None, game.field, game.tetromino.figure, True)
|
||||||
|
x_axis_movement: int = 0
|
||||||
|
for _ in range(rotation):
|
||||||
|
tetermino.rotate()
|
||||||
|
|
||||||
|
while tetermino.move_horizontal(Direction.LEFT): # move maximaly to the left
|
||||||
|
x_axis_movement -= 1
|
||||||
|
|
||||||
|
for _ in range(i):
|
||||||
|
if tetermino.move_horizontal(Direction.RIGHT): # slowly move to the right
|
||||||
|
x_axis_movement += 1
|
||||||
|
|
||||||
|
tetermino.drop()
|
||||||
|
|
||||||
|
score: float = calculate_score(game)
|
||||||
|
|
||||||
|
logger.debug(f"{tetermino.figure.name=:3} {score=:6.6f} {best_score=} {rotation=:1} {x_axis_movement=:1}")
|
||||||
|
|
||||||
|
if best_score is None or score > best_score:
|
||||||
|
best_score = score
|
||||||
|
best_move = BestMove(rotation, x_axis_movement)
|
||||||
|
|
||||||
|
if not tetermino._are_new_positions_valid(
|
||||||
|
[pygame.Vector2(block.pos.x + 1, block.pos.y) for block in tetermino.blocks]
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# logger.debug(f"{field=}")
|
||||||
|
tetermino.kill()
|
||||||
|
|
||||||
|
if not best_move:
|
||||||
|
best_move = BestMove(0, 0)
|
||||||
|
tetermino.kill()
|
||||||
|
phantom_sprites.empty()
|
||||||
|
return best_move
|
||||||
17
src/ai/score.py
Normal file
17
src/ai/score.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import numpy as np
|
||||||
|
from game import Tetris
|
||||||
|
|
||||||
|
from .heuristics import aggregate_height, complete_lines, count_holes, get_bumpiness
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_score(game: Tetris) -> float:
|
||||||
|
field: np.ndarray[int, np.dtype[np.uint8]] = np.where(game.field != None, 1, 0)
|
||||||
|
for block in game.tetromino.blocks:
|
||||||
|
field[int(block.pos.y), int(block.pos.x)] = 1
|
||||||
|
|
||||||
|
height = aggregate_height(field) * -0.510066
|
||||||
|
lines = complete_lines(field) * 0.760666
|
||||||
|
holes = count_holes(field) * -0.35663
|
||||||
|
bumpiness = get_bumpiness(field) * -0.184483
|
||||||
|
|
||||||
|
return height + lines + holes + bumpiness
|
||||||
@ -3,7 +3,7 @@ from .enum import Direction, GameMode, Rotation
|
|||||||
from .figure import Figure
|
from .figure import Figure
|
||||||
from .path import BASE_PATH
|
from .path import BASE_PATH
|
||||||
from .settings import read_settings, save_settings
|
from .settings import read_settings, save_settings
|
||||||
from .tuples import Size
|
from .tuples import BestMove, Size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BASE_PATH",
|
"BASE_PATH",
|
||||||
@ -15,4 +15,5 @@ __all__ = [
|
|||||||
"GameMode",
|
"GameMode",
|
||||||
"read_settings",
|
"read_settings",
|
||||||
"save_settings",
|
"save_settings",
|
||||||
|
"BestMove",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -17,3 +17,16 @@ class Size(NamedTuple):
|
|||||||
if isinstance(other, Size):
|
if isinstance(other, Size):
|
||||||
return Size(self.width - other.width, self.height - other.height)
|
return Size(self.width - other.width, self.height - other.height)
|
||||||
return Size(self.width - other, self.height - other)
|
return Size(self.width - other, self.height - other)
|
||||||
|
|
||||||
|
|
||||||
|
class BestMove(NamedTuple):
|
||||||
|
"""
|
||||||
|
A best move object.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
rotation: The rotation of the best move.
|
||||||
|
x_axis_offset: The x-axis offset of the best move.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rotation: int
|
||||||
|
x_axis_offset: int
|
||||||
|
|||||||
65
tests/ai/test_heuristics.py
Normal file
65
tests/ai/test_heuristics.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from ai.heuristics import aggregate_height, complete_lines, count_holes, get_bumpiness
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeuristics(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.field = np.array(
|
||||||
|
[
|
||||||
|
[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||||
|
[0, 1, 1, 1, 1, 1, 1, 0, 0, 1],
|
||||||
|
[0, 1, 1, 0, 1, 1, 1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 0, 1, 1, 1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.field2 = np.array(
|
||||||
|
[
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
|
||||||
|
[1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
|
||||||
|
[1, 0, 0, 1, 1, 0, 0, 0, 0, 1],
|
||||||
|
[1, 1, 0, 1, 1, 0, 0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.field3 = np.array(
|
||||||
|
[
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 1, 0, 0, 0, 0, 0, 0, 1],
|
||||||
|
[1, 1, 1, 0, 0, 0, 0, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1, 1, 1, 0, 0, 1, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_aggregate_height(self) -> None:
|
||||||
|
self.assertEqual(aggregate_height(self.field), 48)
|
||||||
|
self.assertEqual(aggregate_height(self.field2), 12)
|
||||||
|
self.assertEqual(aggregate_height(self.field3), 30)
|
||||||
|
|
||||||
|
def test_complete_lines(self) -> None:
|
||||||
|
self.assertEqual(complete_lines(self.field), 2)
|
||||||
|
self.assertEqual(complete_lines(self.field2), 0)
|
||||||
|
self.assertEqual(complete_lines(self.field3), 1)
|
||||||
|
|
||||||
|
def test_holes(self) -> None:
|
||||||
|
self.assertEqual(count_holes(self.field), 2)
|
||||||
|
self.assertEqual(count_holes(self.field2), 0)
|
||||||
|
self.assertEqual(count_holes(self.field3), 2)
|
||||||
|
|
||||||
|
def test_bumpiness(self) -> None:
|
||||||
|
self.assertEqual(get_bumpiness(self.field), 6)
|
||||||
|
self.assertEqual(get_bumpiness(self.field2), 11)
|
||||||
|
self.assertEqual(get_bumpiness(self.field3), 7)
|
||||||
@ -1,6 +0,0 @@
|
|||||||
import unittest
|
|
||||||
|
|
||||||
|
|
||||||
class TestBlank(unittest.TestCase):
|
|
||||||
def test(self) -> None:
|
|
||||||
pass
|
|
||||||
@ -27,6 +27,13 @@ parser.add_argument(
|
|||||||
help="Run app with GUI [Default]",
|
help="Run app with GUI [Default]",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-t",
|
||||||
|
"--train",
|
||||||
|
action="store_true",
|
||||||
|
help="Train AI",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(level: str = "warning") -> None:
|
def setup_logger(level: str = "warning") -> None:
|
||||||
from utils import BASE_PATH
|
from utils import BASE_PATH
|
||||||
@ -64,10 +71,14 @@ def main(args) -> None:
|
|||||||
level = "info"
|
level = "info"
|
||||||
else:
|
else:
|
||||||
level = "warning"
|
level = "warning"
|
||||||
|
|
||||||
setup_logger(level)
|
setup_logger(level)
|
||||||
|
|
||||||
run()
|
if args.train: # type: ignore
|
||||||
|
import ai
|
||||||
|
|
||||||
|
ai.run()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user