mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): add checkpoint path as argument on launch
This commit is contained in:
parent
e6cef45d19
commit
0e5e502898
5
main.py
5
main.py
@ -32,9 +32,8 @@ group.add_argument(
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-t",
|
"-t",
|
||||||
"--train",
|
"--train",
|
||||||
type=pos_int,
|
nargs=3,
|
||||||
nargs=2,
|
metavar=("n generations", "n parallels", "checkpoint"),
|
||||||
metavar=("generations", "parallels"),
|
|
||||||
help="Trains the AI",
|
help="Trains the AI",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -53,7 +53,7 @@ def _calc_holes(field: np.ndarray) -> float:
|
|||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
|
|
||||||
def _calc_height_penalty(field: np.ndarray) -> float:
|
def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
|
||||||
column_heights = np.max(
|
column_heights = np.max(
|
||||||
np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0),
|
np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0),
|
||||||
axis=0,
|
axis=0,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import neat
|
import neat
|
||||||
import pygame
|
import pygame
|
||||||
@ -12,7 +13,9 @@ from .visualize import plot_progress, plot_species, plot_stats
|
|||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
gen_count: int = CONFIG.ai.generations, parallel: int = CONFIG.ai.parallels
|
gen_count: int = CONFIG.ai.generations,
|
||||||
|
parallel: int = CONFIG.ai.parallels,
|
||||||
|
checkpoint_path: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Train the AI
|
Train the AI
|
||||||
@ -22,18 +25,19 @@ def train(
|
|||||||
"""
|
"""
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
# population = neat.Checkpointer().restore_checkpoint(
|
population = (
|
||||||
# CONFIG.ai.checkpoint_path / "neat-checkpoint-199"
|
neat.Checkpointer().restore_checkpoint(checkpoint_path)
|
||||||
# )
|
if checkpoint_path
|
||||||
population = neat.Population(config)
|
else neat.Population(config)
|
||||||
|
)
|
||||||
population.add_reporter(neat.StdOutReporter(True))
|
population.add_reporter(neat.StdOutReporter(True))
|
||||||
stats = neat.StatisticsReporter()
|
stats = neat.StatisticsReporter()
|
||||||
population.add_reporter(stats)
|
population.add_reporter(stats)
|
||||||
population.add_reporter(neat.Checkpointer(5, 900))
|
population.add_reporter(neat.Checkpointer(5, 900))
|
||||||
|
|
||||||
pe = neat.ParallelEvaluator(parallel, eval_genome)
|
pe = neat.ParallelEvaluator(int(parallel), eval_genome)
|
||||||
|
|
||||||
winner = population.run(pe.evaluate, gen_count)
|
winner = population.run(pe.evaluate, int(gen_count))
|
||||||
plot_stats(
|
plot_stats(
|
||||||
stats,
|
stats,
|
||||||
ylog=False,
|
ylog=False,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user