feat(utils, ai): set AI paths in Config

This commit is contained in:
Kristofers Solo 2024-01-05 16:21:59 +02:00
parent 845f2bd024
commit 4e46047243
3 changed files with 21 additions and 11 deletions

View File

@ -2,14 +2,14 @@ import pickle
from pathlib import Path from pathlib import Path
import neat import neat
from utils import BASE_PATH from utils import CONFIG
def load_genome() -> neat.DefaultGenome: def load_genome() -> neat.DefaultGenome:
with open(BASE_PATH / "winner.pkl", "rb") as f: with open(CONFIG.ai.winner_path, "rb") as f:
return pickle.load(f) return pickle.load(f)
def save_genome(genome: neat.DefaultGenome) -> None: def save_genome(genome: neat.DefaultGenome) -> None:
with open(BASE_PATH / "winner.pkl", "wb") as f: with open(CONFIG.ai.winner_path, "wb") as f:
pickle.dump(genome, f) pickle.dump(genome, f)

View File

@ -3,7 +3,7 @@ import time
import neat import neat
import pygame import pygame
from game import Main from game import Main
from utils import BASE_PATH from utils import BASE_PATH, CONFIG
from .config import get_config from .config import get_config
from .evaluations import eval_genome from .evaluations import eval_genome
@ -12,19 +12,19 @@ from .log import log
from .visualize import plot_progress, plot_species, plot_stats from .visualize import plot_progress, plot_species, plot_stats
def train(gen_count: int, parallel: int = 1) -> None: def train(
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
) -> None:
""" """
Train the AI Train the AI
Args: Args:
gen_count: Number of generations to train. gen_count: Number of generations to train (default is 200).
threads: Number of threads to use (default is 1). threads: Number of threads to use (default is 1).
""" """
config = get_config() config = get_config()
chekpoint_path = BASE_PATH / "checkpoints"
plots_path = BASE_PATH / "plots"
# population = neat.Checkpointer().restore_checkpoint( # population = neat.Checkpointer().restore_checkpoint(
# BASE_PATH / "checkpoints" / "neat-checkpoint-199" # CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
# ) # )
population = neat.Population(config) population = neat.Population(config)
population.add_reporter(neat.StdOutReporter(True)) population.add_reporter(neat.StdOutReporter(True))
@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None:
stats, stats,
ylog=False, ylog=False,
view=False, view=False,
filename=plots_path / "avg_fitness.png", filename=CONFIG.ai.plot_path / "avg_fitness.png",
) )
plot_species(stats, view=False, filename=plots_path / "speciation.png") plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
log.info("Saving best genome") log.info("Saving best genome")
save_genome(winner) save_genome(winner)

View File

@ -58,6 +58,15 @@ class Music:
volume: float = 0.01 volume: float = 0.01
@define
class AI:
generations: int = 200
parallels: int = 1
winner_path: Path = BASE_PATH / "winner"
plot_path: Path = BASE_PATH / "plots"
checkpoint_path: Path = BASE_PATH / "checkpoints"
@define @define
class Config: class Config:
log_level: str = "warning" log_level: str = "warning"
@ -68,6 +77,7 @@ class Config:
font: Font = Font() font: Font = Font()
music: Music = Music() music: Music = Music()
colors = TokyoNightNight() colors = TokyoNightNight()
ai = AI()
fps: int = 60 fps: int = 60