mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(utils, ai): set AI paths in Config
This commit is contained in:
parent
845f2bd024
commit
4e46047243
@ -2,14 +2,14 @@ import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import neat
|
||||
from utils import BASE_PATH
|
||||
from utils import CONFIG
|
||||
|
||||
|
||||
def load_genome() -> neat.DefaultGenome:
|
||||
with open(BASE_PATH / "winner.pkl", "rb") as f:
|
||||
with open(CONFIG.ai.winner_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def save_genome(genome: neat.DefaultGenome) -> None:
|
||||
with open(BASE_PATH / "winner.pkl", "wb") as f:
|
||||
with open(CONFIG.ai.winner_path, "wb") as f:
|
||||
pickle.dump(genome, f)
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
import neat
|
||||
import pygame
|
||||
from game import Main
|
||||
from utils import BASE_PATH
|
||||
from utils import BASE_PATH, CONFIG
|
||||
|
||||
from .config import get_config
|
||||
from .evaluations import eval_genome
|
||||
@ -12,19 +12,19 @@ from .log import log
|
||||
from .visualize import plot_progress, plot_species, plot_stats
|
||||
|
||||
|
||||
def train(gen_count: int, parallel: int = 1) -> None:
|
||||
def train(
|
||||
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
|
||||
) -> None:
|
||||
"""
|
||||
Train the AI
|
||||
Args:
|
||||
gen_count: Number of generations to train.
|
||||
gen_count: Number of generations to train (default is 200).
|
||||
threads: Number of threads to use (default is 1).
|
||||
"""
|
||||
config = get_config()
|
||||
chekpoint_path = BASE_PATH / "checkpoints"
|
||||
plots_path = BASE_PATH / "plots"
|
||||
|
||||
# population = neat.Checkpointer().restore_checkpoint(
|
||||
# BASE_PATH / "checkpoints" / "neat-checkpoint-199"
|
||||
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
|
||||
# )
|
||||
population = neat.Population(config)
|
||||
population.add_reporter(neat.StdOutReporter(True))
|
||||
@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None:
|
||||
stats,
|
||||
ylog=False,
|
||||
view=False,
|
||||
filename=plots_path / "avg_fitness.png",
|
||||
filename=CONFIG.ai.plot_path / "avg_fitness.png",
|
||||
)
|
||||
plot_species(stats, view=False, filename=plots_path / "speciation.png")
|
||||
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
|
||||
|
||||
log.info("Saving best genome")
|
||||
save_genome(winner)
|
||||
|
||||
@ -58,6 +58,15 @@ class Music:
|
||||
volume: float = 0.01
|
||||
|
||||
|
||||
@define
|
||||
class AI:
|
||||
generations: int = 200
|
||||
parallels: int = 1
|
||||
winner_path: Path = BASE_PATH / "winner"
|
||||
plot_path: Path = BASE_PATH / "plots"
|
||||
checkpoint_path: Path = BASE_PATH / "checkpoints"
|
||||
|
||||
|
||||
@define
|
||||
class Config:
|
||||
log_level: str = "warning"
|
||||
@ -68,6 +77,7 @@ class Config:
|
||||
font: Font = Font()
|
||||
music: Music = Music()
|
||||
colors = TokyoNightNight()
|
||||
ai = AI()
|
||||
fps: int = 60
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user