mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
refactor(ai): adjust the fitness calculation
This commit is contained in:
parent
43fb1eb8d2
commit
64f14d178f
@ -1,6 +1,6 @@
|
|||||||
[NEAT]
|
[NEAT]
|
||||||
fitness_criterion = mean
|
fitness_criterion = mean
|
||||||
fitness_threshold = 100000
|
fitness_threshold = 10000
|
||||||
pop_size = 100
|
pop_size = 100
|
||||||
reset_on_extinction = False
|
reset_on_extinction = False
|
||||||
|
|
||||||
@ -44,8 +44,8 @@ node_add_prob = 0.2
|
|||||||
node_delete_prob = 0.2
|
node_delete_prob = 0.2
|
||||||
|
|
||||||
# network parameters
|
# network parameters
|
||||||
num_hidden = 5
|
num_hidden = 3
|
||||||
num_inputs = 200
|
num_inputs = 216
|
||||||
num_outputs = 6
|
num_outputs = 6
|
||||||
|
|
||||||
# node response options
|
# node response options
|
||||||
|
|||||||
27
src/ai/evaluations.py
Executable file → Normal file
27
src/ai/evaluations.py
Executable file → Normal file
@ -22,32 +22,20 @@ def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
|||||||
|
|
||||||
while not game.game_over:
|
while not game.game_over:
|
||||||
current_figure: list[int] = [
|
current_figure: list[int] = [
|
||||||
component for vec in game.tetromino.figure.value.shape for component in vec
|
|
||||||
]
|
|
||||||
|
|
||||||
current_figure_pos: list[int] = [
|
|
||||||
component
|
component
|
||||||
for block in game.tetromino.blocks
|
for block in game.tetromino.blocks
|
||||||
for component in (int(block.pos.x), int(block.pos.y))
|
for component in (int(block.pos.x), int(block.pos.y))
|
||||||
]
|
]
|
||||||
|
|
||||||
next_figures: list[int] = [
|
next_figure: list[int] = [
|
||||||
component
|
vec
|
||||||
for figure in app.next_figures
|
for vec in app.game.get_next_figure().value.shape
|
||||||
for pos in figure.value.shape
|
for vec in (int(vec.x), int(vec.y))
|
||||||
for component in pos
|
|
||||||
]
|
]
|
||||||
|
|
||||||
field: np.ndarray = np.zeros((CONFIG.game.rows, CONFIG.game.columns), dtype=int)
|
field = np.where(game.field != None, 1, 0)
|
||||||
|
|
||||||
block: Block
|
output = net.activate((*next_figure, *current_figure, *field.flatten()))
|
||||||
for block in game.sprites:
|
|
||||||
field[int(block.pos.y), int(block.pos.x)] = 1
|
|
||||||
|
|
||||||
output = net.activate(
|
|
||||||
# (*current_figure, *current_figure_pos, *next_figures, *field.flatten())
|
|
||||||
field.flatten()
|
|
||||||
)
|
|
||||||
|
|
||||||
decision = output.index(max(output))
|
decision = output.index(max(output))
|
||||||
|
|
||||||
@ -64,7 +52,8 @@ def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
|||||||
app.run_game_loop()
|
app.run_game_loop()
|
||||||
moves += 1
|
moves += 1
|
||||||
|
|
||||||
genome.fitness = calculate_fitness(game, field) - moves / 10
|
fitness = calculate_fitness(game)
|
||||||
|
genome.fitness = fitness - fitness / moves
|
||||||
score, lines, level = app.game.score, app.game.lines, app.game.level
|
score, lines, level = app.game.score, app.game.lines, app.game.level
|
||||||
|
|
||||||
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
|
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
|
||||||
|
|||||||
@ -8,26 +8,32 @@ from utils import CONFIG
|
|||||||
from .log import log
|
from .log import log
|
||||||
|
|
||||||
|
|
||||||
def calculate_fitness(game: Game, field: Optional[np.ndarray] = None) -> float:
|
def calculate_fitness(game: Game) -> float:
|
||||||
line_values = _calc_line_values(field)
|
field = np.where(game.field != None, 1, 0)
|
||||||
return game.score * 10.0 + line_values
|
reward, penalty = _calc_height_penalty(field)
|
||||||
|
fitness = game.score * 100 - _calc_holes(field) - penalty + reward
|
||||||
|
return fitness
|
||||||
|
|
||||||
|
|
||||||
def _calc_line_values(field: Optional[np.ndarray]) -> int:
|
def _calc_holes(field: np.ndarray) -> float:
|
||||||
if field is None:
|
height, width = field.shape
|
||||||
return 0
|
penalty = 0
|
||||||
|
|
||||||
line_values = 0
|
for col in range(width):
|
||||||
for idx, line in enumerate(np.flipud(field), start=1):
|
column = field[:, col]
|
||||||
if idx <= 4:
|
holde_indices = np.where(column == 0)[0]
|
||||||
line_values += int(line.sum()) * 5
|
|
||||||
elif idx <= 8:
|
|
||||||
line_values += int(line.sum()) * 3
|
|
||||||
elif idx <= 12:
|
|
||||||
line_values += int(line.sum()) * 0
|
|
||||||
elif idx <= 16:
|
|
||||||
line_values += int(line.sum()) * -5
|
|
||||||
else:
|
|
||||||
line_values += int(line.sum()) * -10
|
|
||||||
|
|
||||||
return line_values
|
if len(holde_indices) > 0:
|
||||||
|
highest_hole = holde_indices[0]
|
||||||
|
penalty += np.sum(field[highest_hole:, col]) * (height - highest_hole)
|
||||||
|
return penalty
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_height_penalty(field: np.ndarray) -> tuple[float, float]:
    """Compute a height-based reward and penalty for an occupancy grid.

    Args:
        field: 2-D array in which cells equal to 1 are counted as occupied.
            Row index 0 is treated as the highest row: an occupied cell in
            row ``r`` contributes a height of ``field.shape[0] - r``.

    Returns:
        A ``(reward, penalty)`` pair.

        * ``reward`` is the mean over columns of ``1 / (height + 1)``, so it
          equals 1.0 for an empty field and shrinks as columns grow taller.
        * ``penalty`` is the mean over columns of ``height * (column_index + 1)``,
          i.e. each column's height weighted by its 1-based position.
    """
    rows, columns = field.shape

    # Height associated with each row, shaped (rows, 1) so it broadcasts
    # across every column of the field.
    row_heights = rows - np.arange(rows)[:, None]

    # Per column: the height of its topmost occupied cell, or 0 when the
    # column holds no occupied cell.
    column_heights = np.max(np.where(field == 1, row_heights, 0), axis=0)

    reward = float(np.mean(1 / (column_heights + 1)))
    penalty = float(np.mean(column_heights * np.arange(1, columns + 1)))
    return reward, penalty
|
||||||
|
|||||||
@ -23,10 +23,10 @@ def train(gen_count: int, parallel: int = 1) -> None:
|
|||||||
chekpoint_path = BASE_PATH / "checkpoints"
|
chekpoint_path = BASE_PATH / "checkpoints"
|
||||||
plots_path = BASE_PATH / "plots"
|
plots_path = BASE_PATH / "plots"
|
||||||
|
|
||||||
population = neat.Checkpointer().restore_checkpoint(
|
# population = neat.Checkpointer().restore_checkpoint(
|
||||||
BASE_PATH / "checkpoints" / "neat-checkpoint-44"
|
# BASE_PATH / "checkpoints" / "neat-checkpoint-199"
|
||||||
)
|
# )
|
||||||
# population = neat.Population(config)
|
population = neat.Population(config)
|
||||||
population.add_reporter(neat.StdOutReporter(True))
|
population.add_reporter(neat.StdOutReporter(True))
|
||||||
stats = neat.StatisticsReporter()
|
stats = neat.StatisticsReporter()
|
||||||
population.add_reporter(stats)
|
population.add_reporter(stats)
|
||||||
@ -39,9 +39,9 @@ def train(gen_count: int, parallel: int = 1) -> None:
|
|||||||
stats,
|
stats,
|
||||||
ylog=False,
|
ylog=False,
|
||||||
view=False,
|
view=False,
|
||||||
filename=plots_path / "avg_fitness.svg",
|
filename=plots_path / "avg_fitness.png",
|
||||||
)
|
)
|
||||||
plot_species(stats, view=False, filename=plots_path / "speciation.svg")
|
plot_species(stats, view=False, filename=plots_path / "speciation.png")
|
||||||
|
|
||||||
log.info("Saving best genome")
|
log.info("Saving best genome")
|
||||||
save_genome(winner)
|
save_genome(winner)
|
||||||
|
|||||||
@ -101,9 +101,9 @@ class Main:
|
|||||||
Returns:
|
Returns:
|
||||||
The next figure in the sequence.
|
The next figure in the sequence.
|
||||||
"""
|
"""
|
||||||
next_shape = self.next_figures.pop(0)
|
next_figure = self.next_figures.pop(0)
|
||||||
self.next_figures.append(Figure.random())
|
self.next_figures.append(Figure.random())
|
||||||
return next_shape
|
return next_figure
|
||||||
|
|
||||||
def _initialize_pygeme(self) -> None:
|
def _initialize_pygeme(self) -> None:
|
||||||
"""Initialize Pygame and set up the display."""
|
"""Initialize Pygame and set up the display."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user