From 43fb1eb8d22aa4329171ac98df4b0fb225f78268 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Thu, 4 Jan 2024 19:38:23 +0200 Subject: [PATCH] refactor(ai) refactor(ai) refactor(ai) adjust ai --- config.txt | 10 +++--- main.py | 18 +++++------ src/ai/evaluation.py | 33 ------------------- src/ai/evaluations.py | 73 +++++++++++++++++++++++++++++++++++++++++++ src/ai/fitness.py | 32 +++++++++++++++++-- src/ai/training.py | 39 +++++++++++++++++++---- src/game/__init__.py | 16 +++++++++- src/game/game.py | 20 +++++++----- src/game/main.py | 30 +++++++++--------- src/game/tetromino.py | 2 +- src/utils/config.py | 2 +- 11 files changed, 192 insertions(+), 83 deletions(-) delete mode 100644 src/ai/evaluation.py create mode 100755 src/ai/evaluations.py diff --git a/config.txt b/config.txt index 04657ea..bac5c80 100644 --- a/config.txt +++ b/config.txt @@ -1,7 +1,7 @@ [NEAT] -fitness_criterion = max -fitness_threshold = 10 -pop_size = 50 +fitness_criterion = mean +fitness_threshold = 100000 +pop_size = 100 reset_on_extinction = False [DefaultGenome] @@ -44,9 +44,9 @@ node_add_prob = 0.2 node_delete_prob = 0.2 # network parameters -num_hidden = 2 +num_hidden = 5 num_inputs = 200 -num_outputs = 5 +num_outputs = 6 # node response options response_init_mean = 1.0 diff --git a/main.py b/main.py index 25e9a54..f2d154b 100755 --- a/main.py +++ b/main.py @@ -6,13 +6,10 @@ from utils import BASE_PATH, CONFIG def pos_int(string: str) -> int: - try: - value = int(string) - except ValueError: - raise argparse.ArgumentTypeError(f"Expected integer, got {string!r}") - if value < 0: - raise argparse.ArgumentTypeError(f"Expected non negative number, got {value}") - return value + ivalue = int(string) + if ivalue <= 0: + raise argparse.ArgumentTypeError(f"{ivalue} is not a positive integer") + return ivalue parser = argparse.ArgumentParser(description="Tetris game with AI") @@ -36,9 +33,8 @@ parser.add_argument( "-t", "--train", type=pos_int, - nargs="?", - const=100, - metavar="int", + nargs=2, + metavar=("generations", "parallels"), help="Trains the AI", ) @@ -70,7 +66,7 @@ def main(args: argparse.ArgumentParser) -> None: if args.train is not None: ai.log.debug("Training the AI") - ai.train(args.train) + ai.train(*args.train) else: game.log.debug("Running the game") game.Main().run() diff --git a/src/ai/evaluation.py b/src/ai/evaluation.py deleted file mode 100644 index ae7d0f7..0000000 --- a/src/ai/evaluation.py +++ /dev/null @@ -1,33 +0,0 @@ -import neat -from game import Main - -from .fitness import calculate_fitness -from .log import log - - -def eval_genomes(genomes, config: neat.Config) -> None: - app = Main() - app.run() - for genome_id, genome in genomes: - genome.fitness = calculate_fitness(app) - net = neat.nn.FeedForwardNetwork.create(genome, config) - while not app.game.game_over(): - output = net.activate(app.game.field) - - decision = output.index(max(output)) - - decisions = { - 0: app.game.move_left, - 1: app.game.move_right, - 2: app.game.rotate, - 3: app.game.rotate_reverse, - 4: app.game.drop, - } - - decisions[decision]() - - genome.fitness = calculate_fitness(app) - log.info( - f"{genome_id=}\t{genome.fitness=}\t{app.game.score=}\t{app.game.lines=}\t{app.game.level=}" - ) - app.game.restart() diff --git a/src/ai/evaluations.py b/src/ai/evaluations.py new file mode 100755 index 0000000..2e3e3ff --- /dev/null +++ b/src/ai/evaluations.py @@ -0,0 +1,73 @@ +import math +import time + +import neat +import numpy as np +import pygame +from game import Block, Main +from utils import CONFIG + +from .fitness import calculate_fitness +from .log import log + + +def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float: + app = Main() + + app.mute() + game = app.game + net = neat.nn.FeedForwardNetwork.create(genome, config) + genome.fitness = 0 + moves = 0 + + while not game.game_over: + current_figure: list[int] = [ + component for vec in game.tetromino.figure.value.shape for component in vec + ] + + current_figure_pos: list[int] = [ + component + for block in game.tetromino.blocks + for component in (int(block.pos.x), int(block.pos.y)) + ] + + next_figures: list[int] = [ + component + for figure in app.next_figures + for pos in figure.value.shape + for component in pos + ] + + field: np.ndarray = np.zeros((CONFIG.game.rows, CONFIG.game.columns), dtype=int) + + block: Block + for block in game.sprites: + field[int(block.pos.y), int(block.pos.x)] = 1 + + output = net.activate( + # (*current_figure, *current_figure_pos, *next_figures, *field.flatten()) + field.flatten() + ) + + decision = output.index(max(output)) + + decisions = { + 0: game.move_left, + 1: game.move_right, + 2: game.move_down, + 3: game.rotate, + 4: game.rotate_reverse, + 5: game.drop, + } + + decisions[decision]() + app.run_game_loop() + moves += 1 + + genome.fitness = calculate_fitness(game, field) - moves / 10 + score, lines, level = app.game.score, app.game.lines, app.game.level + + log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}") + + game.restart() + return genome.fitness diff --git a/src/ai/fitness.py b/src/ai/fitness.py index d1bb757..10c9b6a 100644 --- a/src/ai/fitness.py +++ b/src/ai/fitness.py @@ -1,5 +1,33 @@ +from typing import Optional + import neat +import numpy as np +from game import Game +from utils import CONFIG + +from .log import log -def calculate_fitness(app) -> float | int: - return app.game.score + app.game.lines * 100 + app.game.level * 1000 +def calculate_fitness(game: Game, field: Optional[np.ndarray] = None) -> float: + line_values = _calc_line_values(field) + return game.score * 10.0 + line_values + + +def _calc_line_values(field: Optional[np.ndarray]) -> int: + if field is None: + return 0 + + line_values = 0 + for idx, line in enumerate(np.flipud(field), start=1): + if idx <= 4: + line_values += int(line.sum()) * 5 + elif idx <= 8: + line_values += int(line.sum()) * 3 + elif idx <= 12: + line_values += int(line.sum()) * 0 + elif idx <= 16: + line_values += int(line.sum()) * -5 + else: + line_values += int(line.sum()) * -10 + + return line_values diff --git a/src/ai/training.py b/src/ai/training.py index 7abf8d2..f7202ec 100644 --- a/src/ai/training.py +++ b/src/ai/training.py @@ -1,20 +1,47 @@ +import time + import neat +import pygame +from game import Main from utils import BASE_PATH from .config import get_config -from .evaluations import eval_genomes +from .evaluations import eval_genome from .io import save_genome from .log import log +from .visualize import plot_progress, plot_species, plot_stats -def train(generations: int) -> None: - """Train the AI""" +def train(gen_count: int, parallel: int = 1) -> None: + """ + Train the AI + Args: + gen_count: Number of generations to train. + threads: Number of threads to use (default is 1). + """ config = get_config() - population = neat.Population(config) + chekpoint_path = BASE_PATH / "checkpoints" + plots_path = BASE_PATH / "plots" + + population = neat.Checkpointer().restore_checkpoint( + BASE_PATH / "checkpoints" / "neat-checkpoint-44" + ) + # population = neat.Population(config) population.add_reporter(neat.StdOutReporter(True)) - population.add_reporter(neat.StatisticsReporter()) + stats = neat.StatisticsReporter() + population.add_reporter(stats) population.add_reporter(neat.Checkpointer(5, 900)) - winner = population.run(eval_genomes, generations) + + pe = neat.ParallelEvaluator(parallel, eval_genome) + + winner = population.run(pe.evaluate, gen_count) + plot_stats( + stats, + ylog=False, + view=False, + filename=plots_path / "avg_fitness.svg", + ) + plot_species(stats, view=False, filename=plots_path / "speciation.svg") log.info("Saving best genome") save_genome(winner) diff --git a/src/game/__init__.py b/src/game/__init__.py index 6e0f71d..5029dd7 100644 --- a/src/game/__init__.py +++ b/src/game/__init__.py @@ -1,6 +1,20 @@ +from .block import Block from .game import Game from .log import log from .main import Main +from .preview import Preview from .score import Score +from .tetromino import Tetromino +from .timer import Timer, Timers -__all__ = ["log", "Main", "Game", "Score"] +__all__ = [ + "log", + "Main", + "Game", + "Score", + "Block", + "Tetromino", + "Preview", + "Timer", + "Timers", +] diff --git a/src/game/game.py b/src/game/game.py index 1dd2309..10dd6fc 100644 --- a/src/game/game.py +++ b/src/game/game.py @@ -35,6 +35,7 @@ class Game: level: Current game level. score: Current game score. lines: Number of lines cleared. + game_over: True if the game is over, False otherwise. landing_sound: Sound effect for landing blocks. """ @@ -112,10 +113,12 @@ class Game: def create_new_tetromino(self) -> None: """Create a new tetromino and perform necessary actions.""" self._play_landing_sound() - if self.game_over(): - self.restart() - self._check_finished_rows() + + self.game_over = self._check_game_over() + # if self.game_over: + # self.restart() + self.tetromino = Tetromino( self.sprites, self.create_new_tetromino, @@ -123,7 +126,7 @@ class Game: self.get_next_figure(), ) - def game_over(self) -> bool: + def _check_game_over(self) -> bool: """ Check if the game is over. @@ -131,16 +134,17 @@ class Game: True if the game is over, False otherwise. """ for block in self.tetromino.blocks: - if block.pos.y < 0: - log.info("Game over!") + if block.pos.y <= 0: + # log.info("Game over!") return True return False def restart(self) -> None: """Restart the game.""" - log.info("Restarting the game") + # log.info("Restarting the game") self._reset_game_state() self._initialize_field_and_tetromino() + self.game_over = False def mute(self) -> None: """Mute the game.""" @@ -285,6 +289,7 @@ class Game: self.level: int = 1 self.score: int = 0 self.lines: int = 0 + self.game_over: bool = False def _initialize_sound(self) -> None: """Initialize game sounds.""" @@ -368,7 +373,6 @@ class Game: def _reset_game_state(self) -> None: """Reset the game state.""" - log.debug("Resetting game state") self.sprites.empty() self._initialize_field_and_tetromino() self._initialize_game_state() diff --git a/src/game/main.py b/src/game/main.py index 22f44f9..8fba1c0 100644 --- a/src/game/main.py +++ b/src/game/main.py @@ -26,7 +26,7 @@ class Main: """ def __init__(self) -> None: - log.info("Initializing the game") + # log.info("Initializing the game") self._initialize_pygeme() self._initialize_game_components() self._start_background_music() @@ -38,7 +38,19 @@ class Main: def run(self) -> None: """Run the main game loop.""" while True: - self._run_game_loop() + self.run_game_loop() + + def run_game_loop(self) -> None: + """Run a single iteration of the game loop.""" + self.draw() + self.handle_events() + + self.game.run() + self.score.run() + self.preview.run(self.next_figures) + + pygame.display.update() + self.clock.tick(CONFIG.fps) def handle_events(self) -> None: """Handle Pygame events.""" @@ -103,7 +115,7 @@ class Main: def _initialize_game_components(self) -> None: """Initialize game-related components.""" - self.next_figures = self._generate_next_figures() + self.next_figures: list[Figure] = self._generate_next_figures() self.game = Game(self._get_next_figure, self._update_score) self.score = Score() @@ -114,15 +126,3 @@ class Main: self.music = pygame.mixer.Sound(CONFIG.music.background) self.music.set_volume(CONFIG.music.volume) self.music.play(-1) - - def _run_game_loop(self) -> None: - """Run a single iteration of the game loop.""" - self.draw() - self.handle_events() - - self.game.run() - self.score.run() - self.preview.run(self.next_figures) - - pygame.display.update() - self.clock.tick(CONFIG.fps) diff --git a/src/game/tetromino.py b/src/game/tetromino.py index f2cd75a..0de04b4 100644 --- a/src/game/tetromino.py +++ b/src/game/tetromino.py @@ -151,7 +151,7 @@ class Tetromino: """ return all( 0 <= pos.x < CONFIG.game.columns - and 0 <= pos.y <= CONFIG.game.rows + and 0 <= pos.y < CONFIG.game.rows and not self.field[int(pos.y), int(pos.x)] for pos in new_positions ) diff --git a/src/utils/config.py b/src/utils/config.py index 80ad0db..8dc7e5e 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -21,7 +21,7 @@ class Game: size: Size = Size(columns * cell.width, rows * cell.width) pos: Vec2 = Vec2(padding, padding) offset: Vec2 = Vec2(columns // 2, -1) - initial_speed: float | int = 400 + initial_speed: float | int = 50 movment_delay: int = 200 rotation_delay: int = 200 score: dict[int, int] = {1: 40, 2: 100, 3: 300, 4: 1200}