feat(ai): add checkpoint path as argument on launch

This commit is contained in:
Kristofers Solo 2024-01-05 16:34:26 +02:00
parent e6cef45d19
commit 0e5e502898
4 changed files with 14 additions and 11 deletions

View File

View File

@ -32,9 +32,8 @@ group.add_argument(
parser.add_argument(
"-t",
"--train",
type=pos_int,
nargs=2,
metavar=("generations", "parallels"),
nargs=3,
metavar=("n generations", "n parallels", "checkpoint"),
help="Trains the AI",
)

View File

@ -53,7 +53,7 @@ def _calc_holes(field: np.ndarray) -> float:
return penalty
def _calc_height_penalty(field: np.ndarray) -> float:
def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
column_heights = np.max(
np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0),
axis=0,

View File

@ -1,4 +1,5 @@
import time
from typing import Optional
import neat
import pygame
@ -12,7 +13,9 @@ from .visualize import plot_progress, plot_species, plot_stats
def train(
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
gen_count: int = CONFIG.ai.generations,
parallel: int = CONFIG.ai.parallels,
checkpoint_path: Optional[str] = None,
) -> None:
"""
Train the AI
@ -22,18 +25,19 @@ def train(
"""
config = get_config()
# population = neat.Checkpointer().restore_checkpoint(
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
# )
population = neat.Population(config)
population = (
neat.Checkpointer().restore_checkpoint(checkpoint_path)
if checkpoint_path
else neat.Population(config)
)
population.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
population.add_reporter(stats)
population.add_reporter(neat.Checkpointer(5, 900))
pe = neat.ParallelEvaluator(parallel, eval_genome)
pe = neat.ParallelEvaluator(int(parallel), eval_genome)
winner = population.run(pe.evaluate, gen_count)
winner = population.run(pe.evaluate, int(gen_count))
plot_stats(
stats,
ylog=False,