mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): adjust the fitness calculations
This commit is contained in:
parent
7fd44f2834
commit
cf952c4c12
4
config
4
config
@ -6,9 +6,9 @@ reset_on_extinction = False
|
||||
|
||||
[DefaultGenome]
|
||||
# node activation options
|
||||
activation_default = relu
|
||||
activation_default = identity
|
||||
activation_mutate_rate = 0.0
|
||||
activation_options = relu
|
||||
activation_options = identity
|
||||
|
||||
# node aggregation options
|
||||
aggregation_default = sum
|
||||
|
||||
5
main.py
5
main.py
@ -32,8 +32,9 @@ group.add_argument(
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--train",
|
||||
nargs=3,
|
||||
metavar=("n generations", "n parallels", "checkpoint"),
|
||||
type=pos_int,
|
||||
nargs=2,
|
||||
metavar=("n generations", "n parallels"),
|
||||
help="Trains the AI",
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
||||
game = app.game
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
genome.fitness = 0
|
||||
moves = 0
|
||||
|
||||
while not game.game_over:
|
||||
current_figure: list[int] = [
|
||||
@ -50,10 +49,8 @@ def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
||||
|
||||
decisions[decision]()
|
||||
app.run_game_loop()
|
||||
moves += 1
|
||||
|
||||
fitness = calculate_fitness(game)
|
||||
genome.fitness = fitness - fitness / moves
|
||||
genome.fitness = calculate_fitness(field)
|
||||
score, lines, level = app.game.score, app.game.lines, app.game.level
|
||||
|
||||
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .fitness import calculate_fitness
|
||||
from .calculate import calculate_fitness
|
||||
|
||||
__all__ = ["calculate_fitness"]
|
||||
|
||||
41
src/ai/fitness/calculate.py
Normal file
41
src/ai/fitness/calculate.py
Normal file
@ -0,0 +1,41 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai.log import log
|
||||
|
||||
from .bumpiness import get_bumpiness
|
||||
from .holes import get_holes, get_holes_sum
|
||||
from .peaks import get_peaks, get_peaks_max, get_peaks_sum
|
||||
from .transitions import get_col_transition, get_row_transition
|
||||
from .wells import get_wells, get_wells_max
|
||||
|
||||
|
||||
def calculate_fitness(field: np.ndarray) -> float:
|
||||
peaks = get_peaks(field=field)
|
||||
holes = get_holes(field=field)
|
||||
highest_peak = get_peaks_max(peaks=peaks)
|
||||
wells = get_wells(peaks=peaks)
|
||||
|
||||
agg_height = get_peaks_sum(peaks=peaks)
|
||||
n_holes = get_holes_sum(field=field)
|
||||
bumpiness = get_bumpiness(peaks=peaks)
|
||||
num_pits = np.count_nonzero(np.count_nonzero(field, axis=0) == 0)
|
||||
max_wells = get_wells_max(wells=wells)
|
||||
n_cols_with_holes = np.count_nonzero(np.array(holes) > 0)
|
||||
row_transitions = get_row_transition(field=field, highest_peak=highest_peak)
|
||||
col_transitions = get_col_transition(field=field, peaks=peaks)
|
||||
cleared = np.count_nonzero(np.mean(field, axis=1))
|
||||
|
||||
fitness = (
|
||||
agg_height
|
||||
+ n_holes
|
||||
+ bumpiness
|
||||
+ num_pits
|
||||
+ max_wells
|
||||
+ n_cols_with_holes
|
||||
+ row_transitions
|
||||
+ col_transitions
|
||||
+ cleared
|
||||
)
|
||||
return -float(fitness)
|
||||
@ -1,37 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import neat
|
||||
import numpy as np
|
||||
from game import Game
|
||||
from utils import CONFIG
|
||||
|
||||
|
||||
def calculate_fitness(game: Game) -> float:
|
||||
field = np.where(game.field != None, 1, 0)
|
||||
reward, penalty = _calc_height_penalty(field)
|
||||
fitness = game.score * 100 - _calc_holes(field) - penalty + reward
|
||||
return fitness
|
||||
|
||||
|
||||
def _calc_holes(field: np.ndarray) -> float:
|
||||
height, width = field.shape
|
||||
penalty = 0
|
||||
|
||||
for col in range(width):
|
||||
column = field[:, col]
|
||||
holde_indices = np.where(column == 0)[0]
|
||||
|
||||
if len(holde_indices) > 0:
|
||||
highest_hole = holde_indices[0]
|
||||
penalty += np.sum(field[highest_hole:, col]) * (height - highest_hole)
|
||||
return penalty
|
||||
|
||||
|
||||
def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
|
||||
column_heights = np.max(
|
||||
np.where(field == 1, field.shape[0] - np.arange(field.shape[0])[:, None], 0),
|
||||
axis=0,
|
||||
)
|
||||
reward = np.mean(1 / (column_heights + 1))
|
||||
penalty = np.mean(column_heights * np.arange(1, field.shape[1] + 1))
|
||||
return reward, penalty
|
||||
@ -2,8 +2,6 @@ from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai.log import log
|
||||
|
||||
|
||||
def get_peaks(field: np.ndarray) -> np.ndarray:
|
||||
peaks = np.where(field == 1, field.shape[0] - np.argmax(field, axis=0), 0)
|
||||
|
||||
@ -16,7 +16,6 @@ def get_wells(
|
||||
peaks = get_peaks(field)
|
||||
|
||||
wells = np.zeros_like(peaks)
|
||||
log.debug(f"{peaks=}")
|
||||
|
||||
first_well = peaks[1] - peaks[0]
|
||||
wells[0] = first_well if first_well > 0 else 0
|
||||
@ -33,5 +32,14 @@ def get_wells(
|
||||
|
||||
wells[idx] = well_l if well_l >= well_r else well_r
|
||||
|
||||
log.debug(f"{wells=}")
|
||||
return wells
|
||||
|
||||
|
||||
def get_wells_max(
|
||||
*, wells: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> int:
|
||||
if wells is None and field is None:
|
||||
raise ValueError("wells and field cannot both be None")
|
||||
elif wells is None:
|
||||
wells = get_wells(field)
|
||||
return int(np.max(wells))
|
||||
|
||||
@ -22,6 +22,7 @@ def train(
|
||||
Args:
|
||||
gen_count: Number of generations to train (default is 200).
|
||||
threads: Number of threads to use (default is 1).
|
||||
checkpoint_path: Path to a checkpoint file to resume training from.
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
@ -34,12 +35,16 @@ def train(
|
||||
stats = neat.StatisticsReporter()
|
||||
population.add_reporter(stats)
|
||||
population.add_reporter(
|
||||
neat.Checkpointer(CONFIG.ai.checkpoint_interval, CONFIG.ai.checkpoint_delay)
|
||||
neat.Checkpointer(
|
||||
CONFIG.ai.checkpoint.generation_interval,
|
||||
CONFIG.ai.checkpoint.time_interval,
|
||||
CONFIG.ai.checkpoint.filename_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
pe = neat.ParallelEvaluator(int(parallel), eval_genome)
|
||||
pe = neat.ParallelEvaluator(parallel, eval_genome)
|
||||
|
||||
winner = population.run(pe.evaluate, int(gen_count))
|
||||
winner = population.run(pe.evaluate, gen_count)
|
||||
plot_stats(
|
||||
stats,
|
||||
ylog=False,
|
||||
|
||||
@ -58,16 +58,21 @@ class Music:
|
||||
volume: float = 0.01
|
||||
|
||||
|
||||
@define
|
||||
class Checkpoint:
|
||||
generation_interval: int = 10
|
||||
time_interval: float = 900
|
||||
filename_prefix: str = str(BASE_PATH / "checkpoints" / "neat-checkpoint-")
|
||||
|
||||
|
||||
@define
|
||||
class AI:
|
||||
generations: int = 200
|
||||
parallels: int = 1
|
||||
winner_path: Path = BASE_PATH / "winner"
|
||||
plot_path: Path = BASE_PATH / "plots"
|
||||
checkpoint_path: Path = BASE_PATH / "checkpoints"
|
||||
config_path: Path = BASE_PATH / "config"
|
||||
checkpoint_interval: int = 10
|
||||
checkpoint_delay: int = 900
|
||||
checkpoint: Checkpoint = Checkpoint()
|
||||
|
||||
|
||||
@define
|
||||
|
||||
Loading…
Reference in New Issue
Block a user