mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(utils, ai): set AI paths in Config
This commit is contained in:
parent
845f2bd024
commit
4e46047243
@ -2,14 +2,14 @@ import pickle
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import neat
|
import neat
|
||||||
from utils import BASE_PATH
|
from utils import CONFIG
|
||||||
|
|
||||||
|
|
||||||
def load_genome() -> neat.DefaultGenome:
|
def load_genome() -> neat.DefaultGenome:
|
||||||
with open(BASE_PATH / "winner.pkl", "rb") as f:
|
with open(CONFIG.ai.winner_path, "rb") as f:
|
||||||
return pickle.load(f)
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
def save_genome(genome: neat.DefaultGenome) -> None:
|
def save_genome(genome: neat.DefaultGenome) -> None:
|
||||||
with open(BASE_PATH / "winner.pkl", "wb") as f:
|
with open(CONFIG.ai.winner_path, "wb") as f:
|
||||||
pickle.dump(genome, f)
|
pickle.dump(genome, f)
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import time
|
|||||||
import neat
|
import neat
|
||||||
import pygame
|
import pygame
|
||||||
from game import Main
|
from game import Main
|
||||||
from utils import BASE_PATH
|
from utils import BASE_PATH, CONFIG
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
from .evaluations import eval_genome
|
from .evaluations import eval_genome
|
||||||
@ -12,19 +12,19 @@ from .log import log
|
|||||||
from .visualize import plot_progress, plot_species, plot_stats
|
from .visualize import plot_progress, plot_species, plot_stats
|
||||||
|
|
||||||
|
|
||||||
def train(gen_count: int, parallel: int = 1) -> None:
|
def train(
|
||||||
|
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Train the AI
|
Train the AI
|
||||||
Args:
|
Args:
|
||||||
gen_count: Number of generations to train.
|
gen_count: Number of generations to train (default is 200).
|
||||||
threads: Number of threads to use (default is 1).
|
threads: Number of threads to use (default is 1).
|
||||||
"""
|
"""
|
||||||
config = get_config()
|
config = get_config()
|
||||||
chekpoint_path = BASE_PATH / "checkpoints"
|
|
||||||
plots_path = BASE_PATH / "plots"
|
|
||||||
|
|
||||||
# population = neat.Checkpointer().restore_checkpoint(
|
# population = neat.Checkpointer().restore_checkpoint(
|
||||||
# BASE_PATH / "checkpoints" / "neat-checkpoint-199"
|
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
|
||||||
# )
|
# )
|
||||||
population = neat.Population(config)
|
population = neat.Population(config)
|
||||||
population.add_reporter(neat.StdOutReporter(True))
|
population.add_reporter(neat.StdOutReporter(True))
|
||||||
@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None:
|
|||||||
stats,
|
stats,
|
||||||
ylog=False,
|
ylog=False,
|
||||||
view=False,
|
view=False,
|
||||||
filename=plots_path / "avg_fitness.png",
|
filename=CONFIG.ai.plot_path / "avg_fitness.png",
|
||||||
)
|
)
|
||||||
plot_species(stats, view=False, filename=plots_path / "speciation.png")
|
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
|
||||||
|
|
||||||
log.info("Saving best genome")
|
log.info("Saving best genome")
|
||||||
save_genome(winner)
|
save_genome(winner)
|
||||||
|
|||||||
@ -58,6 +58,15 @@ class Music:
|
|||||||
volume: float = 0.01
|
volume: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
@define
|
||||||
|
class AI:
|
||||||
|
generations: int = 200
|
||||||
|
parallels: int = 1
|
||||||
|
winner_path: Path = BASE_PATH / "winner"
|
||||||
|
plot_path: Path = BASE_PATH / "plots"
|
||||||
|
checkpoint_path: Path = BASE_PATH / "checkpoints"
|
||||||
|
|
||||||
|
|
||||||
@define
|
@define
|
||||||
class Config:
|
class Config:
|
||||||
log_level: str = "warning"
|
log_level: str = "warning"
|
||||||
@ -68,6 +77,7 @@ class Config:
|
|||||||
font: Font = Font()
|
font: Font = Font()
|
||||||
music: Music = Music()
|
music: Music = Music()
|
||||||
colors = TokyoNightNight()
|
colors = TokyoNightNight()
|
||||||
|
ai = AI()
|
||||||
fps: int = 60
|
fps: int = 60
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user