feat(ai): add checkpoint path as argument on launch

This commit is contained in:
Kristofers Solo 2024-01-05 16:34:26 +02:00
parent e6cef45d19
commit 0e5e502898
4 changed files with 14 additions and 11 deletions

View File

View File

@ -32,9 +32,8 @@ group.add_argument(
parser.add_argument( parser.add_argument(
"-t", "-t",
"--train", "--train",
type=pos_int, nargs=3,
nargs=2, metavar=("n generations", "n parallels", "checkpoint"),
metavar=("generations", "parallels"),
help="Trains the AI", help="Trains the AI",
) )

View File

@ -53,7 +53,7 @@ def _calc_holes(field: np.ndarray) -> float:
return penalty return penalty
def _calc_height_penalty(field: np.ndarray) -> float: def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
column_heights = np.max( column_heights = np.max(
np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0), np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0),
axis=0, axis=0,

View File

@ -1,4 +1,5 @@
import time import time
from typing import Optional
import neat import neat
import pygame import pygame
@ -12,7 +13,9 @@ from .visualize import plot_progress, plot_species, plot_stats
def train( def train(
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels gen_count: int = CONFIG.ai.generations,
parallel: int = CONFIG.ai.parallels,
checkpoint_path: Optional[str] = None,
) -> None: ) -> None:
""" """
Train the AI Train the AI
@ -22,18 +25,19 @@ def train(
""" """
config = get_config() config = get_config()
# population = neat.Checkpointer().restore_checkpoint( population = (
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199" neat.Checkpointer().restore_checkpoint(checkpoint_path)
# ) if checkpoint_path
population = neat.Population(config) else neat.Population(config)
)
population.add_reporter(neat.StdOutReporter(True)) population.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter() stats = neat.StatisticsReporter()
population.add_reporter(stats) population.add_reporter(stats)
population.add_reporter(neat.Checkpointer(5, 900)) population.add_reporter(neat.Checkpointer(5, 900))
pe = neat.ParallelEvaluator(parallel, eval_genome) pe = neat.ParallelEvaluator(int(parallel), eval_genome)
winner = population.run(pe.evaluate, gen_count) winner = population.run(pe.evaluate, int(gen_count))
plot_stats( plot_stats(
stats, stats,
ylog=False, ylog=False,