refactor(ai): adjust the fitness calculation

This commit is contained in:
Kristofers Solo 2024-01-05 15:51:35 +02:00
parent 43fb1eb8d2
commit 64f14d178f
5 changed files with 44 additions and 49 deletions

View File

@ -1,6 +1,6 @@
[NEAT]
fitness_criterion = mean
fitness_threshold = 100000
fitness_threshold = 10000
pop_size = 100
reset_on_extinction = False
@ -44,8 +44,8 @@ node_add_prob = 0.2
node_delete_prob = 0.2
# network parameters
num_hidden = 5
num_inputs = 200
num_hidden = 3
num_inputs = 216
num_outputs = 6
# node response options

27
src/ai/evaluations.py Executable file → Normal file
View File

@ -22,32 +22,20 @@ def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
while not game.game_over:
current_figure: list[int] = [
component for vec in game.tetromino.figure.value.shape for component in vec
]
current_figure_pos: list[int] = [
component
for block in game.tetromino.blocks
for component in (int(block.pos.x), int(block.pos.y))
]
next_figures: list[int] = [
component
for figure in app.next_figures
for pos in figure.value.shape
for component in pos
next_figure: list[int] = [
vec
for vec in app.game.get_next_figure().value.shape
for vec in (int(vec.x), int(vec.y))
]
field: np.ndarray = np.zeros((CONFIG.game.rows, CONFIG.game.columns), dtype=int)
field = np.where(game.field != None, 1, 0)
block: Block
for block in game.sprites:
field[int(block.pos.y), int(block.pos.x)] = 1
output = net.activate(
# (*current_figure, *current_figure_pos, *next_figures, *field.flatten())
field.flatten()
)
output = net.activate((*next_figure, *current_figure, *field.flatten()))
decision = output.index(max(output))
@ -64,7 +52,8 @@ def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
app.run_game_loop()
moves += 1
genome.fitness = calculate_fitness(game, field) - moves / 10
fitness = calculate_fitness(game)
genome.fitness = fitness - fitness / moves
score, lines, level = app.game.score, app.game.lines, app.game.level
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")

View File

@ -8,26 +8,32 @@ from utils import CONFIG
from .log import log
def calculate_fitness(game: Game, field: Optional[np.ndarray] = None) -> float:
line_values = _calc_line_values(field)
return game.score * 10.0 + line_values
def calculate_fitness(game: Game) -> float:
    """Combine the game score with shape-based terms into one fitness value.

    The result is ``game.score * 100`` minus the value from ``_calc_holes``
    minus the penalty from ``_calc_height_penalty`` plus the reward from
    ``_calc_height_penalty``.

    Args:
        game: Game whose ``score`` and ``field`` attributes are read.

    Returns:
        The fitness value for the current state of ``game``.
    """
    # Elementwise comparison on purpose: cells that are not None become 1,
    # the rest 0 (assumes game.field is a NumPy array -- TODO confirm).
    occupancy = np.where(game.field != None, 1, 0)
    height_reward, height_penalty = _calc_height_penalty(occupancy)
    hole_penalty = _calc_holes(occupancy)
    return game.score * 100 - hole_penalty - height_penalty + height_reward
def _calc_line_values(field: Optional[np.ndarray]) -> int:
if field is None:
return 0
def _calc_holes(field: np.ndarray) -> float:
height, width = field.shape
penalty = 0
line_values = 0
for idx, line in enumerate(np.flipud(field), start=1):
if idx <= 4:
line_values += int(line.sum()) * 5
elif idx <= 8:
line_values += int(line.sum()) * 3
elif idx <= 12:
line_values += int(line.sum()) * 0
elif idx <= 16:
line_values += int(line.sum()) * -5
else:
line_values += int(line.sum()) * -10
for col in range(width):
column = field[:, col]
holde_indices = np.where(column == 0)[0]
return line_values
if len(holde_indices) > 0:
highest_hole = holde_indices[0]
penalty += np.sum(field[highest_hole:, col]) * (height - highest_hole)
return penalty
def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
    """Derive a reward and a penalty from the column heights of a field.

    Args:
        field: 2-D array in which occupied cells hold 1. Row 0 counts as the
            highest row and the last row as height 1.

    Returns:
        A ``(reward, penalty)`` pair:
        ``reward`` is the mean of ``1 / (height + 1)`` over all columns, so
        it equals 1.0 for an empty field and shrinks as columns grow.
        ``penalty`` is the mean of each column height multiplied by its
        1-based column index.
    """
    n_rows = field.shape[0]
    n_cols = field.shape[1]

    # Height of every cell measured from the bottom, broadcast across columns.
    cell_heights = n_rows - np.arange(n_rows)[:, None]

    # Tallest occupied cell per column; columns without blocks get height 0.
    column_heights = np.max(np.where(field == 1, cell_heights, 0), axis=0)

    reward = np.mean(1 / (column_heights + 1))
    penalty = np.mean(column_heights * np.arange(1, n_cols + 1))
    return reward, penalty

View File

@ -23,10 +23,10 @@ def train(gen_count: int, parallel: int = 1) -> None:
chekpoint_path = BASE_PATH / "checkpoints"
plots_path = BASE_PATH / "plots"
population = neat.Checkpointer().restore_checkpoint(
BASE_PATH / "checkpoints" / "neat-checkpoint-44"
)
# population = neat.Population(config)
# population = neat.Checkpointer().restore_checkpoint(
# BASE_PATH / "checkpoints" / "neat-checkpoint-199"
# )
population = neat.Population(config)
population.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
population.add_reporter(stats)
@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None:
stats,
ylog=False,
view=False,
filename=plots_path / "avg_fitness.svg",
filename=plots_path / "avg_fitness.png",
)
plot_species(stats, view=False, filename=plots_path / "speciation.svg")
plot_species(stats, view=False, filename=plots_path / "speciation.png")
log.info("Saving best genome")
save_genome(winner)

View File

@ -101,9 +101,9 @@ class Main:
Returns:
The next figure in the sequence.
"""
next_shape = self.next_figures.pop(0)
next_figure = self.next_figures.pop(0)
self.next_figures.append(Figure.random())
return next_shape
return next_figure
def _initialize_pygeme(self) -> None:
"""Initialize Pygame and set up the display."""