refactor(ai)

refactor(ai)

refactor(ai)

adjust ai
This commit is contained in:
Kristofers Solo 2024-01-04 19:38:23 +02:00
parent 33554bf0e0
commit 43fb1eb8d2
11 changed files with 192 additions and 83 deletions

View File

@ -1,7 +1,7 @@
[NEAT] [NEAT]
fitness_criterion = max fitness_criterion = mean
fitness_threshold = 10 fitness_threshold = 100000
pop_size = 50 pop_size = 100
reset_on_extinction = False reset_on_extinction = False
[DefaultGenome] [DefaultGenome]
@ -44,9 +44,9 @@ node_add_prob = 0.2
node_delete_prob = 0.2 node_delete_prob = 0.2
# network parameters # network parameters
num_hidden = 2 num_hidden = 5
num_inputs = 200 num_inputs = 200
num_outputs = 5 num_outputs = 6
# node response options # node response options
response_init_mean = 1.0 response_init_mean = 1.0

18
main.py
View File

@ -6,13 +6,10 @@ from utils import BASE_PATH, CONFIG
def pos_int(string: str) -> int: def pos_int(string: str) -> int:
try: ivalue = int(string)
value = int(string) if ivalue <= 0:
except ValueError: raise argparse.ArgumentTypeError(f"{ivalue} is not a positive integer")
raise argparse.ArgumentTypeError(f"Expected integer, got {string!r}") return ivalue
if value < 0:
raise argparse.ArgumentTypeError(f"Expected non negative number, got {value}")
return value
parser = argparse.ArgumentParser(description="Tetris game with AI") parser = argparse.ArgumentParser(description="Tetris game with AI")
@ -36,9 +33,8 @@ parser.add_argument(
"-t", "-t",
"--train", "--train",
type=pos_int, type=pos_int,
nargs="?", nargs=2,
const=100, metavar=("generations", "parallels"),
metavar="int",
help="Trains the AI", help="Trains the AI",
) )
@ -70,7 +66,7 @@ def main(args: argparse.ArgumentParser) -> None:
if args.train is not None: if args.train is not None:
ai.log.debug("Training the AI") ai.log.debug("Training the AI")
ai.train(args.train) ai.train(*args.train)
else: else:
game.log.debug("Running the game") game.log.debug("Running the game")
game.Main().run() game.Main().run()

View File

@ -1,33 +0,0 @@
import neat
from game import Main
from .fitness import calculate_fitness
from .log import log
def eval_genomes(genomes, config: neat.Config) -> None:
app = Main()
app.run()
for genome_id, genome in genomes:
genome.fitness = calculate_fitness(app)
net = neat.nn.FeedForwardNetwork.create(genome, config)
while not app.game.game_over():
output = net.activate(app.game.field)
decision = output.index(max(output))
decisions = {
0: app.game.move_left,
1: app.game.move_right,
2: app.game.rotate,
3: app.game.rotate_reverse,
4: app.game.drop,
}
decisions[decision]()
genome.fitness = calculate_fitness(app)
log.info(
f"{genome_id=}\t{genome.fitness=}\t{app.game.score=}\t{app.game.lines=}\t{app.game.level=}"
)
app.game.restart()

73
src/ai/evaluations.py Executable file
View File

@ -0,0 +1,73 @@
import math
import time
import neat
import numpy as np
import pygame
from game import Block, Main
from utils import CONFIG
from .fitness import calculate_fitness
from .log import log
def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
app = Main()
app.mute()
game = app.game
net = neat.nn.FeedForwardNetwork.create(genome, config)
genome.fitness = 0
moves = 0
while not game.game_over:
current_figure: list[int] = [
component for vec in game.tetromino.figure.value.shape for component in vec
]
current_figure_pos: list[int] = [
component
for block in game.tetromino.blocks
for component in (int(block.pos.x), int(block.pos.y))
]
next_figures: list[int] = [
component
for figure in app.next_figures
for pos in figure.value.shape
for component in pos
]
field: np.ndarray = np.zeros((CONFIG.game.rows, CONFIG.game.columns), dtype=int)
block: Block
for block in game.sprites:
field[int(block.pos.y), int(block.pos.x)] = 1
output = net.activate(
# (*current_figure, *current_figure_pos, *next_figures, *field.flatten())
field.flatten()
)
decision = output.index(max(output))
decisions = {
0: game.move_left,
1: game.move_right,
2: game.move_down,
3: game.rotate,
4: game.rotate_reverse,
5: game.drop,
}
decisions[decision]()
app.run_game_loop()
moves += 1
genome.fitness = calculate_fitness(game, field) - moves / 10
score, lines, level = app.game.score, app.game.lines, app.game.level
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
game.restart()
return genome.fitness

View File

@ -1,5 +1,33 @@
from typing import Optional
import neat import neat
import numpy as np
from game import Game
from utils import CONFIG
from .log import log
def calculate_fitness(app) -> float | int: def calculate_fitness(game: Game, field: Optional[np.ndarray] = None) -> float:
return app.game.score + app.game.lines * 100 + app.game.level * 1000 line_values = _calc_line_values(field)
return game.score * 10.0 + line_values
def _calc_line_values(field: Optional[np.ndarray]) -> int:
if field is None:
return 0
line_values = 0
for idx, line in enumerate(np.flipud(field), start=1):
if idx <= 4:
line_values += int(line.sum()) * 5
elif idx <= 8:
line_values += int(line.sum()) * 3
elif idx <= 12:
line_values += int(line.sum()) * 0
elif idx <= 16:
line_values += int(line.sum()) * -5
else:
line_values += int(line.sum()) * -10
return line_values

View File

@ -1,20 +1,47 @@
import time
import neat import neat
import pygame
from game import Main
from utils import BASE_PATH from utils import BASE_PATH
from .config import get_config from .config import get_config
from .evaluations import eval_genomes from .evaluations import eval_genome
from .io import save_genome from .io import save_genome
from .log import log from .log import log
from .visualize import plot_progress, plot_species, plot_stats
def train(generations: int) -> None: def train(gen_count: int, parallel: int = 1) -> None:
"""Train the AI""" """
Train the AI
Args:
gen_count: Number of generations to train.
threads: Number of threads to use (default is 1).
"""
config = get_config() config = get_config()
population = neat.Population(config) chekpoint_path = BASE_PATH / "checkpoints"
plots_path = BASE_PATH / "plots"
population = neat.Checkpointer().restore_checkpoint(
BASE_PATH / "checkpoints" / "neat-checkpoint-44"
)
# population = neat.Population(config)
population.add_reporter(neat.StdOutReporter(True)) population.add_reporter(neat.StdOutReporter(True))
population.add_reporter(neat.StatisticsReporter()) stats = neat.StatisticsReporter()
population.add_reporter(stats)
population.add_reporter(neat.Checkpointer(5, 900)) population.add_reporter(neat.Checkpointer(5, 900))
winner = population.run(eval_genomes, generations)
pe = neat.ParallelEvaluator(parallel, eval_genome)
winner = population.run(pe.evaluate, gen_count)
plot_stats(
stats,
ylog=False,
view=False,
filename=plots_path / "avg_fitness.svg",
)
plot_species(stats, view=False, filename=plots_path / "speciation.svg")
log.info("Saving best genome") log.info("Saving best genome")
save_genome(winner) save_genome(winner)

View File

@ -1,6 +1,20 @@
from .block import Block
from .game import Game from .game import Game
from .log import log from .log import log
from .main import Main from .main import Main
from .preview import Preview
from .score import Score from .score import Score
from .tetromino import Tetromino
from .timer import Timer, Timers
__all__ = ["log", "Main", "Game", "Score"] __all__ = [
"log",
"Main",
"Game",
"Score",
"Block",
"Tetromino",
"Preview",
"Timer",
"Timers",
]

View File

@ -35,6 +35,7 @@ class Game:
level: Current game level. level: Current game level.
score: Current game score. score: Current game score.
lines: Number of lines cleared. lines: Number of lines cleared.
game_over: True if the game is over, False otherwise.
landing_sound: Sound effect for landing blocks. landing_sound: Sound effect for landing blocks.
""" """
@ -112,10 +113,12 @@ class Game:
def create_new_tetromino(self) -> None: def create_new_tetromino(self) -> None:
"""Create a new tetromino and perform necessary actions.""" """Create a new tetromino and perform necessary actions."""
self._play_landing_sound() self._play_landing_sound()
if self.game_over():
self.restart()
self._check_finished_rows() self._check_finished_rows()
self.game_over = self._check_game_over()
# if self.game_over:
# self.restart()
self.tetromino = Tetromino( self.tetromino = Tetromino(
self.sprites, self.sprites,
self.create_new_tetromino, self.create_new_tetromino,
@ -123,7 +126,7 @@ class Game:
self.get_next_figure(), self.get_next_figure(),
) )
def game_over(self) -> bool: def _check_game_over(self) -> bool:
""" """
Check if the game is over. Check if the game is over.
@ -131,16 +134,17 @@ class Game:
True if the game is over, False otherwise. True if the game is over, False otherwise.
""" """
for block in self.tetromino.blocks: for block in self.tetromino.blocks:
if block.pos.y < 0: if block.pos.y <= 0:
log.info("Game over!") # log.info("Game over!")
return True return True
return False return False
def restart(self) -> None: def restart(self) -> None:
"""Restart the game.""" """Restart the game."""
log.info("Restarting the game") # log.info("Restarting the game")
self._reset_game_state() self._reset_game_state()
self._initialize_field_and_tetromino() self._initialize_field_and_tetromino()
self.game_over = False
def mute(self) -> None: def mute(self) -> None:
"""Mute the game.""" """Mute the game."""
@ -285,6 +289,7 @@ class Game:
self.level: int = 1 self.level: int = 1
self.score: int = 0 self.score: int = 0
self.lines: int = 0 self.lines: int = 0
self.game_over: bool = False
def _initialize_sound(self) -> None: def _initialize_sound(self) -> None:
"""Initialize game sounds.""" """Initialize game sounds."""
@ -368,7 +373,6 @@ class Game:
def _reset_game_state(self) -> None: def _reset_game_state(self) -> None:
"""Reset the game state.""" """Reset the game state."""
log.debug("Resetting game state")
self.sprites.empty() self.sprites.empty()
self._initialize_field_and_tetromino() self._initialize_field_and_tetromino()
self._initialize_game_state() self._initialize_game_state()

View File

@ -26,7 +26,7 @@ class Main:
""" """
def __init__(self) -> None: def __init__(self) -> None:
log.info("Initializing the game") # log.info("Initializing the game")
self._initialize_pygeme() self._initialize_pygeme()
self._initialize_game_components() self._initialize_game_components()
self._start_background_music() self._start_background_music()
@ -38,7 +38,19 @@ class Main:
def run(self) -> None: def run(self) -> None:
"""Run the main game loop.""" """Run the main game loop."""
while True: while True:
self._run_game_loop() self.run_game_loop()
def run_game_loop(self) -> None:
"""Run a single iteration of the game loop."""
self.draw()
self.handle_events()
self.game.run()
self.score.run()
self.preview.run(self.next_figures)
pygame.display.update()
self.clock.tick(CONFIG.fps)
def handle_events(self) -> None: def handle_events(self) -> None:
"""Handle Pygame events.""" """Handle Pygame events."""
@ -103,7 +115,7 @@ class Main:
def _initialize_game_components(self) -> None: def _initialize_game_components(self) -> None:
"""Initialize game-related components.""" """Initialize game-related components."""
self.next_figures = self._generate_next_figures() self.next_figures: list[Figure] = self._generate_next_figures()
self.game = Game(self._get_next_figure, self._update_score) self.game = Game(self._get_next_figure, self._update_score)
self.score = Score() self.score = Score()
@ -114,15 +126,3 @@ class Main:
self.music = pygame.mixer.Sound(CONFIG.music.background) self.music = pygame.mixer.Sound(CONFIG.music.background)
self.music.set_volume(CONFIG.music.volume) self.music.set_volume(CONFIG.music.volume)
self.music.play(-1) self.music.play(-1)
def _run_game_loop(self) -> None:
"""Run a single iteration of the game loop."""
self.draw()
self.handle_events()
self.game.run()
self.score.run()
self.preview.run(self.next_figures)
pygame.display.update()
self.clock.tick(CONFIG.fps)

View File

@ -151,7 +151,7 @@ class Tetromino:
""" """
return all( return all(
0 <= pos.x < CONFIG.game.columns 0 <= pos.x < CONFIG.game.columns
and 0 <= pos.y <= CONFIG.game.rows and 0 <= pos.y < CONFIG.game.rows
and not self.field[int(pos.y), int(pos.x)] and not self.field[int(pos.y), int(pos.x)]
for pos in new_positions for pos in new_positions
) )

View File

@ -21,7 +21,7 @@ class Game:
size: Size = Size(columns * cell.width, rows * cell.width) size: Size = Size(columns * cell.width, rows * cell.width)
pos: Vec2 = Vec2(padding, padding) pos: Vec2 = Vec2(padding, padding)
offset: Vec2 = Vec2(columns // 2, -1) offset: Vec2 = Vec2(columns // 2, -1)
initial_speed: float | int = 400 initial_speed: float | int = 50
movment_delay: int = 200 movment_delay: int = 200
rotation_delay: int = 200 rotation_delay: int = 200
score: dict[int, int] = {1: 40, 2: 100, 3: 300, 4: 1200} score: dict[int, int] = {1: 40, 2: 100, 3: 300, 4: 1200}