mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
refactor(ai)
refactor(ai) refactor(ai) adjust ai
This commit is contained in:
parent
33554bf0e0
commit
43fb1eb8d2
10
config.txt
10
config.txt
@ -1,7 +1,7 @@
|
|||||||
[NEAT]
|
[NEAT]
|
||||||
fitness_criterion = max
|
fitness_criterion = mean
|
||||||
fitness_threshold = 10
|
fitness_threshold = 100000
|
||||||
pop_size = 50
|
pop_size = 100
|
||||||
reset_on_extinction = False
|
reset_on_extinction = False
|
||||||
|
|
||||||
[DefaultGenome]
|
[DefaultGenome]
|
||||||
@ -44,9 +44,9 @@ node_add_prob = 0.2
|
|||||||
node_delete_prob = 0.2
|
node_delete_prob = 0.2
|
||||||
|
|
||||||
# network parameters
|
# network parameters
|
||||||
num_hidden = 2
|
num_hidden = 5
|
||||||
num_inputs = 200
|
num_inputs = 200
|
||||||
num_outputs = 5
|
num_outputs = 6
|
||||||
|
|
||||||
# node response options
|
# node response options
|
||||||
response_init_mean = 1.0
|
response_init_mean = 1.0
|
||||||
|
|||||||
18
main.py
18
main.py
@ -6,13 +6,10 @@ from utils import BASE_PATH, CONFIG
|
|||||||
|
|
||||||
|
|
||||||
def pos_int(string: str) -> int:
|
def pos_int(string: str) -> int:
|
||||||
try:
|
ivalue = int(string)
|
||||||
value = int(string)
|
if ivalue <= 0:
|
||||||
except ValueError:
|
raise argparse.ArgumentTypeError(f"{ivalue} is not a positive integer")
|
||||||
raise argparse.ArgumentTypeError(f"Expected integer, got {string!r}")
|
return ivalue
|
||||||
if value < 0:
|
|
||||||
raise argparse.ArgumentTypeError(f"Expected non negative number, got {value}")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Tetris game with AI")
|
parser = argparse.ArgumentParser(description="Tetris game with AI")
|
||||||
@ -36,9 +33,8 @@ parser.add_argument(
|
|||||||
"-t",
|
"-t",
|
||||||
"--train",
|
"--train",
|
||||||
type=pos_int,
|
type=pos_int,
|
||||||
nargs="?",
|
nargs=2,
|
||||||
const=100,
|
metavar=("generations", "parallels"),
|
||||||
metavar="int",
|
|
||||||
help="Trains the AI",
|
help="Trains the AI",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,7 +66,7 @@ def main(args: argparse.ArgumentParser) -> None:
|
|||||||
|
|
||||||
if args.train is not None:
|
if args.train is not None:
|
||||||
ai.log.debug("Training the AI")
|
ai.log.debug("Training the AI")
|
||||||
ai.train(args.train)
|
ai.train(*args.train)
|
||||||
else:
|
else:
|
||||||
game.log.debug("Running the game")
|
game.log.debug("Running the game")
|
||||||
game.Main().run()
|
game.Main().run()
|
||||||
|
|||||||
@ -1,33 +0,0 @@
|
|||||||
import neat
|
|
||||||
from game import Main
|
|
||||||
|
|
||||||
from .fitness import calculate_fitness
|
|
||||||
from .log import log
|
|
||||||
|
|
||||||
|
|
||||||
def eval_genomes(genomes, config: neat.Config) -> None:
|
|
||||||
app = Main()
|
|
||||||
app.run()
|
|
||||||
for genome_id, genome in genomes:
|
|
||||||
genome.fitness = calculate_fitness(app)
|
|
||||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
|
||||||
while not app.game.game_over():
|
|
||||||
output = net.activate(app.game.field)
|
|
||||||
|
|
||||||
decision = output.index(max(output))
|
|
||||||
|
|
||||||
decisions = {
|
|
||||||
0: app.game.move_left,
|
|
||||||
1: app.game.move_right,
|
|
||||||
2: app.game.rotate,
|
|
||||||
3: app.game.rotate_reverse,
|
|
||||||
4: app.game.drop,
|
|
||||||
}
|
|
||||||
|
|
||||||
decisions[decision]()
|
|
||||||
|
|
||||||
genome.fitness = calculate_fitness(app)
|
|
||||||
log.info(
|
|
||||||
f"{genome_id=}\t{genome.fitness=}\t{app.game.score=}\t{app.game.lines=}\t{app.game.level=}"
|
|
||||||
)
|
|
||||||
app.game.restart()
|
|
||||||
73
src/ai/evaluations.py
Executable file
73
src/ai/evaluations.py
Executable file
@ -0,0 +1,73 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
import neat
|
||||||
|
import numpy as np
|
||||||
|
import pygame
|
||||||
|
from game import Block, Main
|
||||||
|
from utils import CONFIG
|
||||||
|
|
||||||
|
from .fitness import calculate_fitness
|
||||||
|
from .log import log
|
||||||
|
|
||||||
|
|
||||||
|
def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
||||||
|
app = Main()
|
||||||
|
|
||||||
|
app.mute()
|
||||||
|
game = app.game
|
||||||
|
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||||
|
genome.fitness = 0
|
||||||
|
moves = 0
|
||||||
|
|
||||||
|
while not game.game_over:
|
||||||
|
current_figure: list[int] = [
|
||||||
|
component for vec in game.tetromino.figure.value.shape for component in vec
|
||||||
|
]
|
||||||
|
|
||||||
|
current_figure_pos: list[int] = [
|
||||||
|
component
|
||||||
|
for block in game.tetromino.blocks
|
||||||
|
for component in (int(block.pos.x), int(block.pos.y))
|
||||||
|
]
|
||||||
|
|
||||||
|
next_figures: list[int] = [
|
||||||
|
component
|
||||||
|
for figure in app.next_figures
|
||||||
|
for pos in figure.value.shape
|
||||||
|
for component in pos
|
||||||
|
]
|
||||||
|
|
||||||
|
field: np.ndarray = np.zeros((CONFIG.game.rows, CONFIG.game.columns), dtype=int)
|
||||||
|
|
||||||
|
block: Block
|
||||||
|
for block in game.sprites:
|
||||||
|
field[int(block.pos.y), int(block.pos.x)] = 1
|
||||||
|
|
||||||
|
output = net.activate(
|
||||||
|
# (*current_figure, *current_figure_pos, *next_figures, *field.flatten())
|
||||||
|
field.flatten()
|
||||||
|
)
|
||||||
|
|
||||||
|
decision = output.index(max(output))
|
||||||
|
|
||||||
|
decisions = {
|
||||||
|
0: game.move_left,
|
||||||
|
1: game.move_right,
|
||||||
|
2: game.move_down,
|
||||||
|
3: game.rotate,
|
||||||
|
4: game.rotate_reverse,
|
||||||
|
5: game.drop,
|
||||||
|
}
|
||||||
|
|
||||||
|
decisions[decision]()
|
||||||
|
app.run_game_loop()
|
||||||
|
moves += 1
|
||||||
|
|
||||||
|
genome.fitness = calculate_fitness(game, field) - moves / 10
|
||||||
|
score, lines, level = app.game.score, app.game.lines, app.game.level
|
||||||
|
|
||||||
|
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
|
||||||
|
|
||||||
|
game.restart()
|
||||||
|
return genome.fitness
|
||||||
@ -1,5 +1,33 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import neat
|
import neat
|
||||||
|
import numpy as np
|
||||||
|
from game import Game
|
||||||
|
from utils import CONFIG
|
||||||
|
|
||||||
|
from .log import log
|
||||||
|
|
||||||
|
|
||||||
def calculate_fitness(app) -> float | int:
|
def calculate_fitness(game: Game, field: Optional[np.ndarray] = None) -> float:
|
||||||
return app.game.score + app.game.lines * 100 + app.game.level * 1000
|
line_values = _calc_line_values(field)
|
||||||
|
return game.score * 10.0 + line_values
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_line_values(field: Optional[np.ndarray]) -> int:
|
||||||
|
if field is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
line_values = 0
|
||||||
|
for idx, line in enumerate(np.flipud(field), start=1):
|
||||||
|
if idx <= 4:
|
||||||
|
line_values += int(line.sum()) * 5
|
||||||
|
elif idx <= 8:
|
||||||
|
line_values += int(line.sum()) * 3
|
||||||
|
elif idx <= 12:
|
||||||
|
line_values += int(line.sum()) * 0
|
||||||
|
elif idx <= 16:
|
||||||
|
line_values += int(line.sum()) * -5
|
||||||
|
else:
|
||||||
|
line_values += int(line.sum()) * -10
|
||||||
|
|
||||||
|
return line_values
|
||||||
|
|||||||
@ -1,20 +1,47 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
import neat
|
import neat
|
||||||
|
import pygame
|
||||||
|
from game import Main
|
||||||
from utils import BASE_PATH
|
from utils import BASE_PATH
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
from .evaluations import eval_genomes
|
from .evaluations import eval_genome
|
||||||
from .io import save_genome
|
from .io import save_genome
|
||||||
from .log import log
|
from .log import log
|
||||||
|
from .visualize import plot_progress, plot_species, plot_stats
|
||||||
|
|
||||||
|
|
||||||
def train(generations: int) -> None:
|
def train(gen_count: int, parallel: int = 1) -> None:
|
||||||
"""Train the AI"""
|
"""
|
||||||
|
Train the AI
|
||||||
|
Args:
|
||||||
|
gen_count: Number of generations to train.
|
||||||
|
threads: Number of threads to use (default is 1).
|
||||||
|
"""
|
||||||
config = get_config()
|
config = get_config()
|
||||||
population = neat.Population(config)
|
chekpoint_path = BASE_PATH / "checkpoints"
|
||||||
|
plots_path = BASE_PATH / "plots"
|
||||||
|
|
||||||
|
population = neat.Checkpointer().restore_checkpoint(
|
||||||
|
BASE_PATH / "checkpoints" / "neat-checkpoint-44"
|
||||||
|
)
|
||||||
|
# population = neat.Population(config)
|
||||||
population.add_reporter(neat.StdOutReporter(True))
|
population.add_reporter(neat.StdOutReporter(True))
|
||||||
population.add_reporter(neat.StatisticsReporter())
|
stats = neat.StatisticsReporter()
|
||||||
|
population.add_reporter(stats)
|
||||||
population.add_reporter(neat.Checkpointer(5, 900))
|
population.add_reporter(neat.Checkpointer(5, 900))
|
||||||
winner = population.run(eval_genomes, generations)
|
|
||||||
|
pe = neat.ParallelEvaluator(parallel, eval_genome)
|
||||||
|
|
||||||
|
winner = population.run(pe.evaluate, gen_count)
|
||||||
|
plot_stats(
|
||||||
|
stats,
|
||||||
|
ylog=False,
|
||||||
|
view=False,
|
||||||
|
filename=plots_path / "avg_fitness.svg",
|
||||||
|
)
|
||||||
|
plot_species(stats, view=False, filename=plots_path / "speciation.svg")
|
||||||
|
|
||||||
log.info("Saving best genome")
|
log.info("Saving best genome")
|
||||||
save_genome(winner)
|
save_genome(winner)
|
||||||
|
|||||||
@ -1,6 +1,20 @@
|
|||||||
|
from .block import Block
|
||||||
from .game import Game
|
from .game import Game
|
||||||
from .log import log
|
from .log import log
|
||||||
from .main import Main
|
from .main import Main
|
||||||
|
from .preview import Preview
|
||||||
from .score import Score
|
from .score import Score
|
||||||
|
from .tetromino import Tetromino
|
||||||
|
from .timer import Timer, Timers
|
||||||
|
|
||||||
__all__ = ["log", "Main", "Game", "Score"]
|
__all__ = [
|
||||||
|
"log",
|
||||||
|
"Main",
|
||||||
|
"Game",
|
||||||
|
"Score",
|
||||||
|
"Block",
|
||||||
|
"Tetromino",
|
||||||
|
"Preview",
|
||||||
|
"Timer",
|
||||||
|
"Timers",
|
||||||
|
]
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class Game:
|
|||||||
level: Current game level.
|
level: Current game level.
|
||||||
score: Current game score.
|
score: Current game score.
|
||||||
lines: Number of lines cleared.
|
lines: Number of lines cleared.
|
||||||
|
game_over: True if the game is over, False otherwise.
|
||||||
landing_sound: Sound effect for landing blocks.
|
landing_sound: Sound effect for landing blocks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -112,10 +113,12 @@ class Game:
|
|||||||
def create_new_tetromino(self) -> None:
|
def create_new_tetromino(self) -> None:
|
||||||
"""Create a new tetromino and perform necessary actions."""
|
"""Create a new tetromino and perform necessary actions."""
|
||||||
self._play_landing_sound()
|
self._play_landing_sound()
|
||||||
if self.game_over():
|
|
||||||
self.restart()
|
|
||||||
|
|
||||||
self._check_finished_rows()
|
self._check_finished_rows()
|
||||||
|
|
||||||
|
self.game_over = self._check_game_over()
|
||||||
|
# if self.game_over:
|
||||||
|
# self.restart()
|
||||||
|
|
||||||
self.tetromino = Tetromino(
|
self.tetromino = Tetromino(
|
||||||
self.sprites,
|
self.sprites,
|
||||||
self.create_new_tetromino,
|
self.create_new_tetromino,
|
||||||
@ -123,7 +126,7 @@ class Game:
|
|||||||
self.get_next_figure(),
|
self.get_next_figure(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def game_over(self) -> bool:
|
def _check_game_over(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the game is over.
|
Check if the game is over.
|
||||||
|
|
||||||
@ -131,16 +134,17 @@ class Game:
|
|||||||
True if the game is over, False otherwise.
|
True if the game is over, False otherwise.
|
||||||
"""
|
"""
|
||||||
for block in self.tetromino.blocks:
|
for block in self.tetromino.blocks:
|
||||||
if block.pos.y < 0:
|
if block.pos.y <= 0:
|
||||||
log.info("Game over!")
|
# log.info("Game over!")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def restart(self) -> None:
|
def restart(self) -> None:
|
||||||
"""Restart the game."""
|
"""Restart the game."""
|
||||||
log.info("Restarting the game")
|
# log.info("Restarting the game")
|
||||||
self._reset_game_state()
|
self._reset_game_state()
|
||||||
self._initialize_field_and_tetromino()
|
self._initialize_field_and_tetromino()
|
||||||
|
self.game_over = False
|
||||||
|
|
||||||
def mute(self) -> None:
|
def mute(self) -> None:
|
||||||
"""Mute the game."""
|
"""Mute the game."""
|
||||||
@ -285,6 +289,7 @@ class Game:
|
|||||||
self.level: int = 1
|
self.level: int = 1
|
||||||
self.score: int = 0
|
self.score: int = 0
|
||||||
self.lines: int = 0
|
self.lines: int = 0
|
||||||
|
self.game_over: bool = False
|
||||||
|
|
||||||
def _initialize_sound(self) -> None:
|
def _initialize_sound(self) -> None:
|
||||||
"""Initialize game sounds."""
|
"""Initialize game sounds."""
|
||||||
@ -368,7 +373,6 @@ class Game:
|
|||||||
|
|
||||||
def _reset_game_state(self) -> None:
|
def _reset_game_state(self) -> None:
|
||||||
"""Reset the game state."""
|
"""Reset the game state."""
|
||||||
log.debug("Resetting game state")
|
|
||||||
self.sprites.empty()
|
self.sprites.empty()
|
||||||
self._initialize_field_and_tetromino()
|
self._initialize_field_and_tetromino()
|
||||||
self._initialize_game_state()
|
self._initialize_game_state()
|
||||||
|
|||||||
@ -26,7 +26,7 @@ class Main:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
log.info("Initializing the game")
|
# log.info("Initializing the game")
|
||||||
self._initialize_pygeme()
|
self._initialize_pygeme()
|
||||||
self._initialize_game_components()
|
self._initialize_game_components()
|
||||||
self._start_background_music()
|
self._start_background_music()
|
||||||
@ -38,7 +38,19 @@ class Main:
|
|||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Run the main game loop."""
|
"""Run the main game loop."""
|
||||||
while True:
|
while True:
|
||||||
self._run_game_loop()
|
self.run_game_loop()
|
||||||
|
|
||||||
|
def run_game_loop(self) -> None:
|
||||||
|
"""Run a single iteration of the game loop."""
|
||||||
|
self.draw()
|
||||||
|
self.handle_events()
|
||||||
|
|
||||||
|
self.game.run()
|
||||||
|
self.score.run()
|
||||||
|
self.preview.run(self.next_figures)
|
||||||
|
|
||||||
|
pygame.display.update()
|
||||||
|
self.clock.tick(CONFIG.fps)
|
||||||
|
|
||||||
def handle_events(self) -> None:
|
def handle_events(self) -> None:
|
||||||
"""Handle Pygame events."""
|
"""Handle Pygame events."""
|
||||||
@ -103,7 +115,7 @@ class Main:
|
|||||||
|
|
||||||
def _initialize_game_components(self) -> None:
|
def _initialize_game_components(self) -> None:
|
||||||
"""Initialize game-related components."""
|
"""Initialize game-related components."""
|
||||||
self.next_figures = self._generate_next_figures()
|
self.next_figures: list[Figure] = self._generate_next_figures()
|
||||||
|
|
||||||
self.game = Game(self._get_next_figure, self._update_score)
|
self.game = Game(self._get_next_figure, self._update_score)
|
||||||
self.score = Score()
|
self.score = Score()
|
||||||
@ -114,15 +126,3 @@ class Main:
|
|||||||
self.music = pygame.mixer.Sound(CONFIG.music.background)
|
self.music = pygame.mixer.Sound(CONFIG.music.background)
|
||||||
self.music.set_volume(CONFIG.music.volume)
|
self.music.set_volume(CONFIG.music.volume)
|
||||||
self.music.play(-1)
|
self.music.play(-1)
|
||||||
|
|
||||||
def _run_game_loop(self) -> None:
|
|
||||||
"""Run a single iteration of the game loop."""
|
|
||||||
self.draw()
|
|
||||||
self.handle_events()
|
|
||||||
|
|
||||||
self.game.run()
|
|
||||||
self.score.run()
|
|
||||||
self.preview.run(self.next_figures)
|
|
||||||
|
|
||||||
pygame.display.update()
|
|
||||||
self.clock.tick(CONFIG.fps)
|
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class Tetromino:
|
|||||||
"""
|
"""
|
||||||
return all(
|
return all(
|
||||||
0 <= pos.x < CONFIG.game.columns
|
0 <= pos.x < CONFIG.game.columns
|
||||||
and 0 <= pos.y <= CONFIG.game.rows
|
and 0 <= pos.y < CONFIG.game.rows
|
||||||
and not self.field[int(pos.y), int(pos.x)]
|
and not self.field[int(pos.y), int(pos.x)]
|
||||||
for pos in new_positions
|
for pos in new_positions
|
||||||
)
|
)
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class Game:
|
|||||||
size: Size = Size(columns * cell.width, rows * cell.width)
|
size: Size = Size(columns * cell.width, rows * cell.width)
|
||||||
pos: Vec2 = Vec2(padding, padding)
|
pos: Vec2 = Vec2(padding, padding)
|
||||||
offset: Vec2 = Vec2(columns // 2, -1)
|
offset: Vec2 = Vec2(columns // 2, -1)
|
||||||
initial_speed: float | int = 400
|
initial_speed: float | int = 50
|
||||||
movment_delay: int = 200
|
movment_delay: int = 200
|
||||||
rotation_delay: int = 200
|
rotation_delay: int = 200
|
||||||
score: dict[int, int] = {1: 40, 2: 100, 3: 300, 4: 1200}
|
score: dict[int, int] = {1: 40, 2: 100, 3: 300, 4: 1200}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user