mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): add checkpoint path as argument on launch
This commit is contained in:
parent
e6cef45d19
commit
0e5e502898
5
main.py
5
main.py
@ -32,9 +32,8 @@ group.add_argument(
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--train",
|
||||
type=pos_int,
|
||||
nargs=2,
|
||||
metavar=("generations", "parallels"),
|
||||
nargs=3,
|
||||
metavar=("n generations", "n parallels", "checkpoint"),
|
||||
help="Trains the AI",
|
||||
)
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ def _calc_holes(field: np.ndarray) -> float:
|
||||
return penalty
|
||||
|
||||
|
||||
def _calc_height_penalty(field: np.ndarray) -> float:
|
||||
def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
|
||||
column_heights = np.max(
|
||||
np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0),
|
||||
axis=0,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import neat
|
||||
import pygame
|
||||
@ -12,7 +13,9 @@ from .visualize import plot_progress, plot_species, plot_stats
|
||||
|
||||
|
||||
def train(
|
||||
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
|
||||
gen_count: int = CONFIG.ai.generations,
|
||||
parallel: int = CONFIG.ai.parallels,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Train the AI
|
||||
@ -22,18 +25,19 @@ def train(
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
# population = neat.Checkpointer().restore_checkpoint(
|
||||
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
|
||||
# )
|
||||
population = neat.Population(config)
|
||||
population = (
|
||||
neat.Checkpointer().restore_checkpoint(checkpoint_path)
|
||||
if checkpoint_path
|
||||
else neat.Population(config)
|
||||
)
|
||||
population.add_reporter(neat.StdOutReporter(True))
|
||||
stats = neat.StatisticsReporter()
|
||||
population.add_reporter(stats)
|
||||
population.add_reporter(neat.Checkpointer(5, 900))
|
||||
|
||||
pe = neat.ParallelEvaluator(parallel, eval_genome)
|
||||
pe = neat.ParallelEvaluator(int(parallel), eval_genome)
|
||||
|
||||
winner = population.run(pe.evaluate, gen_count)
|
||||
winner = population.run(pe.evaluate, int(gen_count))
|
||||
plot_stats(
|
||||
stats,
|
||||
ylog=False,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user