feat(utils, ai): set AI paths in Config

This commit is contained in:
Kristofers Solo 2024-01-05 16:21:59 +02:00
parent 845f2bd024
commit 4e46047243
3 changed files with 21 additions and 11 deletions

View File

@ -2,14 +2,14 @@ import pickle
from pathlib import Path
import neat
from utils import BASE_PATH
from utils import CONFIG
def load_genome() -> neat.DefaultGenome:
with open(BASE_PATH / "winner.pkl", "rb") as f:
with open(CONFIG.ai.winner_path, "rb") as f:
return pickle.load(f)
def save_genome(genome: neat.DefaultGenome) -> None:
with open(BASE_PATH / "winner.pkl", "wb") as f:
with open(CONFIG.ai.winner_path, "wb") as f:
pickle.dump(genome, f)

View File

@ -3,7 +3,7 @@ import time
import neat
import pygame
from game import Main
from utils import BASE_PATH
from utils import BASE_PATH, CONFIG
from .config import get_config
from .evaluations import eval_genome
@ -12,19 +12,19 @@ from .log import log
from .visualize import plot_progress, plot_species, plot_stats
def train(gen_count: int, parallel: int = 1) -> None:
def train(
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
) -> None:
"""
Train the AI
Args:
gen_count: Number of generations to train.
gen_count: Number of generations to train (default is 200).
threads: Number of threads to use (default is 1).
"""
config = get_config()
chekpoint_path = BASE_PATH / "checkpoints"
plots_path = BASE_PATH / "plots"
# population = neat.Checkpointer().restore_checkpoint(
# BASE_PATH / "checkpoints" / "neat-checkpoint-199"
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
# )
population = neat.Population(config)
population.add_reporter(neat.StdOutReporter(True))
@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None:
stats,
ylog=False,
view=False,
filename=plots_path / "avg_fitness.png",
filename=CONFIG.ai.plot_path / "avg_fitness.png",
)
plot_species(stats, view=False, filename=plots_path / "speciation.png")
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
log.info("Saving best genome")
save_genome(winner)

View File

@ -58,6 +58,15 @@ class Music:
volume: float = 0.01
@define
class AI:
generations: int = 200
parallels: int = 1
winner_path: Path = BASE_PATH / "winner"
plot_path: Path = BASE_PATH / "plots"
checkpoint_path: Path = BASE_PATH / "checkpoints"
@define
class Config:
log_level: str = "warning"
@ -68,6 +77,7 @@ class Config:
font: Font = Font()
music: Music = Music()
colors = TokyoNightNight()
ai = AI()
fps: int = 60