mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
refactor(ai)
refactor(ai) refactor(ai) adjust ai
This commit is contained in:
parent
33554bf0e0
commit
43fb1eb8d2
10
config.txt
10
config.txt
@ -1,7 +1,7 @@
|
||||
[NEAT]
|
||||
fitness_criterion = max
|
||||
fitness_threshold = 10
|
||||
pop_size = 50
|
||||
fitness_criterion = mean
|
||||
fitness_threshold = 100000
|
||||
pop_size = 100
|
||||
reset_on_extinction = False
|
||||
|
||||
[DefaultGenome]
|
||||
@ -44,9 +44,9 @@ node_add_prob = 0.2
|
||||
node_delete_prob = 0.2
|
||||
|
||||
# network parameters
|
||||
num_hidden = 2
|
||||
num_hidden = 5
|
||||
num_inputs = 200
|
||||
num_outputs = 5
|
||||
num_outputs = 6
|
||||
|
||||
# node response options
|
||||
response_init_mean = 1.0
|
||||
|
||||
18
main.py
18
main.py
@ -6,13 +6,10 @@ from utils import BASE_PATH, CONFIG
|
||||
|
||||
|
||||
def pos_int(string: str) -> int:
|
||||
try:
|
||||
value = int(string)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"Expected integer, got {string!r}")
|
||||
if value < 0:
|
||||
raise argparse.ArgumentTypeError(f"Expected non negative number, got {value}")
|
||||
return value
|
||||
ivalue = int(string)
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError(f"{ivalue} is not a positive integer")
|
||||
return ivalue
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Tetris game with AI")
|
||||
@ -36,9 +33,8 @@ parser.add_argument(
|
||||
"-t",
|
||||
"--train",
|
||||
type=pos_int,
|
||||
nargs="?",
|
||||
const=100,
|
||||
metavar="int",
|
||||
nargs=2,
|
||||
metavar=("generations", "parallels"),
|
||||
help="Trains the AI",
|
||||
)
|
||||
|
||||
@ -70,7 +66,7 @@ def main(args: argparse.ArgumentParser) -> None:
|
||||
|
||||
if args.train is not None:
|
||||
ai.log.debug("Training the AI")
|
||||
ai.train(args.train)
|
||||
ai.train(*args.train)
|
||||
else:
|
||||
game.log.debug("Running the game")
|
||||
game.Main().run()
|
||||
|
||||
@ -1,33 +0,0 @@
|
||||
import neat
|
||||
from game import Main
|
||||
|
||||
from .fitness import calculate_fitness
|
||||
from .log import log
|
||||
|
||||
|
||||
def eval_genomes(genomes, config: neat.Config) -> None:
|
||||
app = Main()
|
||||
app.run()
|
||||
for genome_id, genome in genomes:
|
||||
genome.fitness = calculate_fitness(app)
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
while not app.game.game_over():
|
||||
output = net.activate(app.game.field)
|
||||
|
||||
decision = output.index(max(output))
|
||||
|
||||
decisions = {
|
||||
0: app.game.move_left,
|
||||
1: app.game.move_right,
|
||||
2: app.game.rotate,
|
||||
3: app.game.rotate_reverse,
|
||||
4: app.game.drop,
|
||||
}
|
||||
|
||||
decisions[decision]()
|
||||
|
||||
genome.fitness = calculate_fitness(app)
|
||||
log.info(
|
||||
f"{genome_id=}\t{genome.fitness=}\t{app.game.score=}\t{app.game.lines=}\t{app.game.level=}"
|
||||
)
|
||||
app.game.restart()
|
||||
73
src/ai/evaluations.py
Executable file
73
src/ai/evaluations.py
Executable file
@ -0,0 +1,73 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import neat
|
||||
import numpy as np
|
||||
import pygame
|
||||
from game import Block, Main
|
||||
from utils import CONFIG
|
||||
|
||||
from .fitness import calculate_fitness
|
||||
from .log import log
|
||||
|
||||
|
||||
def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
||||
app = Main()
|
||||
|
||||
app.mute()
|
||||
game = app.game
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
genome.fitness = 0
|
||||
moves = 0
|
||||
|
||||
while not game.game_over:
|
||||
current_figure: list[int] = [
|
||||
component for vec in game.tetromino.figure.value.shape for component in vec
|
||||
]
|
||||
|
||||
current_figure_pos: list[int] = [
|
||||
component
|
||||
for block in game.tetromino.blocks
|
||||
for component in (int(block.pos.x), int(block.pos.y))
|
||||
]
|
||||
|
||||
next_figures: list[int] = [
|
||||
component
|
||||
for figure in app.next_figures
|
||||
for pos in figure.value.shape
|
||||
for component in pos
|
||||
]
|
||||
|
||||
field: np.ndarray = np.zeros((CONFIG.game.rows, CONFIG.game.columns), dtype=int)
|
||||
|
||||
block: Block
|
||||
for block in game.sprites:
|
||||
field[int(block.pos.y), int(block.pos.x)] = 1
|
||||
|
||||
output = net.activate(
|
||||
# (*current_figure, *current_figure_pos, *next_figures, *field.flatten())
|
||||
field.flatten()
|
||||
)
|
||||
|
||||
decision = output.index(max(output))
|
||||
|
||||
decisions = {
|
||||
0: game.move_left,
|
||||
1: game.move_right,
|
||||
2: game.move_down,
|
||||
3: game.rotate,
|
||||
4: game.rotate_reverse,
|
||||
5: game.drop,
|
||||
}
|
||||
|
||||
decisions[decision]()
|
||||
app.run_game_loop()
|
||||
moves += 1
|
||||
|
||||
genome.fitness = calculate_fitness(game, field) - moves / 10
|
||||
score, lines, level = app.game.score, app.game.lines, app.game.level
|
||||
|
||||
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
|
||||
|
||||
game.restart()
|
||||
return genome.fitness
|
||||
@ -1,5 +1,33 @@
|
||||
from typing import Optional
|
||||
|
||||
import neat
|
||||
import numpy as np
|
||||
from game import Game
|
||||
from utils import CONFIG
|
||||
|
||||
from .log import log
|
||||
|
||||
|
||||
def calculate_fitness(app) -> float | int:
|
||||
return app.game.score + app.game.lines * 100 + app.game.level * 1000
|
||||
def calculate_fitness(game: Game, field: Optional[np.ndarray] = None) -> float:
|
||||
line_values = _calc_line_values(field)
|
||||
return game.score * 10.0 + line_values
|
||||
|
||||
|
||||
def _calc_line_values(field: Optional[np.ndarray]) -> int:
|
||||
if field is None:
|
||||
return 0
|
||||
|
||||
line_values = 0
|
||||
for idx, line in enumerate(np.flipud(field), start=1):
|
||||
if idx <= 4:
|
||||
line_values += int(line.sum()) * 5
|
||||
elif idx <= 8:
|
||||
line_values += int(line.sum()) * 3
|
||||
elif idx <= 12:
|
||||
line_values += int(line.sum()) * 0
|
||||
elif idx <= 16:
|
||||
line_values += int(line.sum()) * -5
|
||||
else:
|
||||
line_values += int(line.sum()) * -10
|
||||
|
||||
return line_values
|
||||
|
||||
@ -1,20 +1,47 @@
|
||||
import time
|
||||
|
||||
import neat
|
||||
import pygame
|
||||
from game import Main
|
||||
from utils import BASE_PATH
|
||||
|
||||
from .config import get_config
|
||||
from .evaluations import eval_genomes
|
||||
from .evaluations import eval_genome
|
||||
from .io import save_genome
|
||||
from .log import log
|
||||
from .visualize import plot_progress, plot_species, plot_stats
|
||||
|
||||
|
||||
def train(generations: int) -> None:
|
||||
"""Train the AI"""
|
||||
def train(gen_count: int, parallel: int = 1) -> None:
|
||||
"""
|
||||
Train the AI
|
||||
Args:
|
||||
gen_count: Number of generations to train.
|
||||
threads: Number of threads to use (default is 1).
|
||||
"""
|
||||
config = get_config()
|
||||
population = neat.Population(config)
|
||||
chekpoint_path = BASE_PATH / "checkpoints"
|
||||
plots_path = BASE_PATH / "plots"
|
||||
|
||||
population = neat.Checkpointer().restore_checkpoint(
|
||||
BASE_PATH / "checkpoints" / "neat-checkpoint-44"
|
||||
)
|
||||
# population = neat.Population(config)
|
||||
population.add_reporter(neat.StdOutReporter(True))
|
||||
population.add_reporter(neat.StatisticsReporter())
|
||||
stats = neat.StatisticsReporter()
|
||||
population.add_reporter(stats)
|
||||
population.add_reporter(neat.Checkpointer(5, 900))
|
||||
winner = population.run(eval_genomes, generations)
|
||||
|
||||
pe = neat.ParallelEvaluator(parallel, eval_genome)
|
||||
|
||||
winner = population.run(pe.evaluate, gen_count)
|
||||
plot_stats(
|
||||
stats,
|
||||
ylog=False,
|
||||
view=False,
|
||||
filename=plots_path / "avg_fitness.svg",
|
||||
)
|
||||
plot_species(stats, view=False, filename=plots_path / "speciation.svg")
|
||||
|
||||
log.info("Saving best genome")
|
||||
save_genome(winner)
|
||||
|
||||
@ -1,6 +1,20 @@
|
||||
from .block import Block
|
||||
from .game import Game
|
||||
from .log import log
|
||||
from .main import Main
|
||||
from .preview import Preview
|
||||
from .score import Score
|
||||
from .tetromino import Tetromino
|
||||
from .timer import Timer, Timers
|
||||
|
||||
__all__ = ["log", "Main", "Game", "Score"]
|
||||
__all__ = [
|
||||
"log",
|
||||
"Main",
|
||||
"Game",
|
||||
"Score",
|
||||
"Block",
|
||||
"Tetromino",
|
||||
"Preview",
|
||||
"Timer",
|
||||
"Timers",
|
||||
]
|
||||
|
||||
@ -35,6 +35,7 @@ class Game:
|
||||
level: Current game level.
|
||||
score: Current game score.
|
||||
lines: Number of lines cleared.
|
||||
game_over: True if the game is over, False otherwise.
|
||||
landing_sound: Sound effect for landing blocks.
|
||||
"""
|
||||
|
||||
@ -112,10 +113,12 @@ class Game:
|
||||
def create_new_tetromino(self) -> None:
|
||||
"""Create a new tetromino and perform necessary actions."""
|
||||
self._play_landing_sound()
|
||||
if self.game_over():
|
||||
self.restart()
|
||||
|
||||
self._check_finished_rows()
|
||||
|
||||
self.game_over = self._check_game_over()
|
||||
# if self.game_over:
|
||||
# self.restart()
|
||||
|
||||
self.tetromino = Tetromino(
|
||||
self.sprites,
|
||||
self.create_new_tetromino,
|
||||
@ -123,7 +126,7 @@ class Game:
|
||||
self.get_next_figure(),
|
||||
)
|
||||
|
||||
def game_over(self) -> bool:
|
||||
def _check_game_over(self) -> bool:
|
||||
"""
|
||||
Check if the game is over.
|
||||
|
||||
@ -131,16 +134,17 @@ class Game:
|
||||
True if the game is over, False otherwise.
|
||||
"""
|
||||
for block in self.tetromino.blocks:
|
||||
if block.pos.y < 0:
|
||||
log.info("Game over!")
|
||||
if block.pos.y <= 0:
|
||||
# log.info("Game over!")
|
||||
return True
|
||||
return False
|
||||
|
||||
def restart(self) -> None:
|
||||
"""Restart the game."""
|
||||
log.info("Restarting the game")
|
||||
# log.info("Restarting the game")
|
||||
self._reset_game_state()
|
||||
self._initialize_field_and_tetromino()
|
||||
self.game_over = False
|
||||
|
||||
def mute(self) -> None:
|
||||
"""Mute the game."""
|
||||
@ -285,6 +289,7 @@ class Game:
|
||||
self.level: int = 1
|
||||
self.score: int = 0
|
||||
self.lines: int = 0
|
||||
self.game_over: bool = False
|
||||
|
||||
def _initialize_sound(self) -> None:
|
||||
"""Initialize game sounds."""
|
||||
@ -368,7 +373,6 @@ class Game:
|
||||
|
||||
def _reset_game_state(self) -> None:
|
||||
"""Reset the game state."""
|
||||
log.debug("Resetting game state")
|
||||
self.sprites.empty()
|
||||
self._initialize_field_and_tetromino()
|
||||
self._initialize_game_state()
|
||||
|
||||
@ -26,7 +26,7 @@ class Main:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
log.info("Initializing the game")
|
||||
# log.info("Initializing the game")
|
||||
self._initialize_pygeme()
|
||||
self._initialize_game_components()
|
||||
self._start_background_music()
|
||||
@ -38,7 +38,19 @@ class Main:
|
||||
def run(self) -> None:
|
||||
"""Run the main game loop."""
|
||||
while True:
|
||||
self._run_game_loop()
|
||||
self.run_game_loop()
|
||||
|
||||
def run_game_loop(self) -> None:
|
||||
"""Run a single iteration of the game loop."""
|
||||
self.draw()
|
||||
self.handle_events()
|
||||
|
||||
self.game.run()
|
||||
self.score.run()
|
||||
self.preview.run(self.next_figures)
|
||||
|
||||
pygame.display.update()
|
||||
self.clock.tick(CONFIG.fps)
|
||||
|
||||
def handle_events(self) -> None:
|
||||
"""Handle Pygame events."""
|
||||
@ -103,7 +115,7 @@ class Main:
|
||||
|
||||
def _initialize_game_components(self) -> None:
|
||||
"""Initialize game-related components."""
|
||||
self.next_figures = self._generate_next_figures()
|
||||
self.next_figures: list[Figure] = self._generate_next_figures()
|
||||
|
||||
self.game = Game(self._get_next_figure, self._update_score)
|
||||
self.score = Score()
|
||||
@ -114,15 +126,3 @@ class Main:
|
||||
self.music = pygame.mixer.Sound(CONFIG.music.background)
|
||||
self.music.set_volume(CONFIG.music.volume)
|
||||
self.music.play(-1)
|
||||
|
||||
def _run_game_loop(self) -> None:
|
||||
"""Run a single iteration of the game loop."""
|
||||
self.draw()
|
||||
self.handle_events()
|
||||
|
||||
self.game.run()
|
||||
self.score.run()
|
||||
self.preview.run(self.next_figures)
|
||||
|
||||
pygame.display.update()
|
||||
self.clock.tick(CONFIG.fps)
|
||||
|
||||
@ -151,7 +151,7 @@ class Tetromino:
|
||||
"""
|
||||
return all(
|
||||
0 <= pos.x < CONFIG.game.columns
|
||||
and 0 <= pos.y <= CONFIG.game.rows
|
||||
and 0 <= pos.y < CONFIG.game.rows
|
||||
and not self.field[int(pos.y), int(pos.x)]
|
||||
for pos in new_positions
|
||||
)
|
||||
|
||||
@ -21,7 +21,7 @@ class Game:
|
||||
size: Size = Size(columns * cell.width, rows * cell.width)
|
||||
pos: Vec2 = Vec2(padding, padding)
|
||||
offset: Vec2 = Vec2(columns // 2, -1)
|
||||
initial_speed: float | int = 400
|
||||
initial_speed: float | int = 50
|
||||
movment_delay: int = 200
|
||||
rotation_delay: int = 200
|
||||
score: dict[int, int] = {1: 40, 2: 100, 3: 300, 4: 1200}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user