mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
delete ai
This commit is contained in:
parent
77cd153b10
commit
06f962f5c7
79
config
79
config
@ -1,79 +0,0 @@
|
||||
[NEAT]
|
||||
fitness_criterion = max
|
||||
fitness_threshold = 500
|
||||
pop_size = 50
|
||||
reset_on_extinction = False
|
||||
|
||||
[DefaultGenome]
|
||||
# node activation options
|
||||
activation_default = relu
|
||||
activation_mutate_rate = 0.0
|
||||
activation_options = relu
|
||||
|
||||
# node aggregation options
|
||||
aggregation_default = sum
|
||||
aggregation_mutate_rate = 0.0
|
||||
aggregation_options = sum
|
||||
|
||||
# node bias options
|
||||
bias_init_mean = 0.0
|
||||
bias_init_stdev = 1.0
|
||||
bias_max_value = 30.0
|
||||
bias_min_value = -30.0
|
||||
bias_mutate_power = 0.5
|
||||
bias_mutate_rate = 0.7
|
||||
bias_replace_rate = 0.1
|
||||
|
||||
# genome compatibility options
|
||||
compatibility_disjoint_coefficient = 1.0
|
||||
compatibility_weight_coefficient = 0.5
|
||||
|
||||
# connection add/remove rates
|
||||
conn_add_prob = 0.5
|
||||
conn_delete_prob = 0.5
|
||||
|
||||
# connection enable options
|
||||
enabled_default = True
|
||||
enabled_mutate_rate = 0.01
|
||||
|
||||
feed_forward = True
|
||||
initial_connection = full_direct
|
||||
|
||||
# node add/remove rates
|
||||
node_add_prob = 0.2
|
||||
node_delete_prob = 0.2
|
||||
|
||||
# network parameters
|
||||
num_hidden = 1
|
||||
num_inputs = 200
|
||||
num_outputs = 6
|
||||
|
||||
# node response options
|
||||
response_init_mean = 1.0
|
||||
response_init_stdev = 0.0
|
||||
response_max_value = 30.0
|
||||
response_min_value = -30.0
|
||||
response_mutate_power = 0.0
|
||||
response_mutate_rate = 0.0
|
||||
response_replace_rate = 0.0
|
||||
|
||||
# connection weight options
|
||||
weight_init_mean = 0.0
|
||||
weight_init_stdev = 1.0
|
||||
weight_max_value = 30
|
||||
weight_min_value = -30
|
||||
weight_mutate_power = 0.5
|
||||
weight_mutate_rate = 0.8
|
||||
weight_replace_rate = 0.1
|
||||
|
||||
[DefaultSpeciesSet]
|
||||
compatibility_threshold = 3.0
|
||||
|
||||
[DefaultStagnation]
|
||||
species_fitness_func = max
|
||||
max_stagnation = 20
|
||||
species_elitism = 2
|
||||
|
||||
[DefaultReproduction]
|
||||
elitism = 2
|
||||
survival_threshold = 0.2
|
||||
24
main.py
24
main.py
@ -4,14 +4,6 @@ import argparse
|
||||
from loguru import logger
|
||||
from utils import BASE_PATH, CONFIG, GameMode
|
||||
|
||||
|
||||
def pos_int(string: str) -> int:
|
||||
ivalue = int(string)
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError(f"{ivalue} is not a positive integer")
|
||||
return ivalue
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Tetris game with AI")
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
@ -28,16 +20,6 @@ group.add_argument(
|
||||
help="Verbose",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--train",
|
||||
type=pos_int,
|
||||
nargs=2,
|
||||
metavar=("n generations", "n parallels"),
|
||||
help="Trains the AI",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--graphic",
|
||||
@ -61,14 +43,8 @@ def main(args: argparse.ArgumentParser) -> None:
|
||||
elif args.verbose:
|
||||
CONFIG.log_level = "info"
|
||||
|
||||
import ai
|
||||
import game
|
||||
|
||||
if args.train is not None:
|
||||
ai.log.debug("Training the AI")
|
||||
ai.train(*args.train)
|
||||
# game.Menu(GameMode.AI_TRAINING).run()
|
||||
else:
|
||||
game.log.debug("Running the game")
|
||||
game.Main(GameMode.PLAYER).run()
|
||||
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
from .io import load_genome, save_genome
|
||||
from .log import log
|
||||
from .training import train
|
||||
|
||||
__all__ = ["log", "train", "load_genome", "save_genome"]
|
||||
@ -1,64 +0,0 @@
|
||||
import math
|
||||
import time
|
||||
|
||||
import neat
|
||||
import numpy as np
|
||||
import pygame
|
||||
from game import Main
|
||||
from game.sprites import Block
|
||||
from utils import CONFIG, GameMode
|
||||
|
||||
# from .fitness import calculate_fitness
|
||||
from .log import log
|
||||
from .moves import calculate_fitness
|
||||
|
||||
|
||||
def eval_genome(genome: neat.DefaultGenome, config: neat.Config) -> float:
|
||||
app = Main(GameMode.AI_TRAINING).play()
|
||||
|
||||
game = app.game
|
||||
tetris = game.tetris
|
||||
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||
genome.fitness = 0
|
||||
|
||||
while not tetris.game_over:
|
||||
# current_figure: list[int] = [
|
||||
# component
|
||||
# for block in tetris.tetromino.blocks
|
||||
# for component in (int(block.pos.x), int(block.pos.y))
|
||||
# ]
|
||||
|
||||
# next_figure: list[int] = [
|
||||
# vec
|
||||
# for vec in game.next_figure.value.shape
|
||||
# for vec in (int(vec.x), int(vec.y))
|
||||
# ]
|
||||
|
||||
field = np.where(tetris.field != None, 1, 0)
|
||||
|
||||
for block in tetris.tetromino.blocks:
|
||||
field[int(block.pos.y), int(block.pos.x)] = 2
|
||||
|
||||
output = net.activate(field.flatten())
|
||||
|
||||
decision = output.index(max(output))
|
||||
|
||||
decisions = {
|
||||
0: tetris.move_left,
|
||||
1: tetris.move_right,
|
||||
2: tetris.move_down,
|
||||
3: tetris.rotate,
|
||||
4: tetris.rotate_reverse,
|
||||
5: tetris.drop,
|
||||
}
|
||||
|
||||
decisions[decision]()
|
||||
app.run_game_loop()
|
||||
|
||||
genome.fitness = calculate_fitness(field)
|
||||
score, lines, level = tetris.score, tetris.lines, tetris.level
|
||||
|
||||
log.debug(f"{genome.fitness=:<+6.6}\t{score=:<6} {lines=:<6} {level=:<6}")
|
||||
|
||||
tetris.restart()
|
||||
return genome.fitness
|
||||
@ -1,3 +0,0 @@
|
||||
from .calculate import calculate_fitness
|
||||
|
||||
__all__ = ["calculate_fitness"]
|
||||
@ -1,30 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .peaks import get_peaks
|
||||
|
||||
|
||||
def get_bumpiness(
|
||||
*, peaks: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the bumpiness of a given signal based on peaks.
|
||||
|
||||
Args:
|
||||
peaks: Array containing peak indices. If not provided, it will be computed from the field.
|
||||
field: The signal field. Required if peaks is not provided.
|
||||
|
||||
Returns:
|
||||
The bumpiness of the signal.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `peaks` and `field` are `None`.
|
||||
"""
|
||||
if peaks is None and field is None:
|
||||
raise ValueError("peaks and field cannot both be None")
|
||||
elif peaks is None:
|
||||
peaks = get_peaks(field)
|
||||
|
||||
differences = np.abs(np.diff(peaks))
|
||||
return int(np.sum(differences))
|
||||
@ -1,50 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai.log import log
|
||||
|
||||
from .bumpiness import get_bumpiness
|
||||
from .holes import get_holes, get_holes_sum
|
||||
from .peaks import get_peaks, get_peaks_max, get_peaks_sum
|
||||
from .transitions import get_col_transition, get_row_transition
|
||||
from .wells import get_wells, get_wells_max
|
||||
|
||||
|
||||
def calculate_fitness(field: np.ndarray) -> float:
|
||||
"""
|
||||
Calculate the fitness value for the given field.
|
||||
|
||||
Args:
|
||||
field: The game field.
|
||||
|
||||
Returns:
|
||||
The fitness value.
|
||||
"""
|
||||
peaks = get_peaks(field=field)
|
||||
holes = get_holes(field=field)
|
||||
highest_peak = get_peaks_max(peaks=peaks)
|
||||
wells = get_wells(peaks=peaks)
|
||||
|
||||
agg_height = get_peaks_sum(peaks=peaks)
|
||||
n_holes = get_holes_sum(field=field)
|
||||
bumpiness = get_bumpiness(peaks=peaks)
|
||||
num_pits = np.count_nonzero(np.count_nonzero(field, axis=0) == 0)
|
||||
max_wells = get_wells_max(wells=wells)
|
||||
n_cols_with_holes = np.count_nonzero(np.array(holes) > 0)
|
||||
row_transitions = get_row_transition(field=field, highest_peak=highest_peak)
|
||||
col_transitions = get_col_transition(field=field, peaks=peaks)
|
||||
cleared = np.count_nonzero(np.mean(field, axis=1))
|
||||
|
||||
fitness = (
|
||||
agg_height
|
||||
+ n_holes
|
||||
+ bumpiness
|
||||
+ num_pits
|
||||
+ max_wells
|
||||
+ n_cols_with_holes
|
||||
+ row_transitions
|
||||
+ col_transitions
|
||||
+ cleared
|
||||
)
|
||||
return -float(fitness)
|
||||
@ -1,56 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .peaks import get_peaks
|
||||
|
||||
|
||||
def get_holes(
|
||||
field: np.ndarray,
|
||||
peaks: Optional[np.array] = None,
|
||||
) -> np.array:
|
||||
"""
|
||||
Calculate the number of holes in each column of the given field.
|
||||
|
||||
Args:
|
||||
field: The signal field.
|
||||
peaks: Array containing peak indices. If not provided, it will be computed from the field.
|
||||
|
||||
Returns:
|
||||
Array containing the number of holes in each column.
|
||||
"""
|
||||
if peaks is None:
|
||||
peaks = get_peaks(field)
|
||||
col_count = field.shape[1]
|
||||
holes = np.zeros(col_count, dtype=int)
|
||||
|
||||
for col in range(col_count):
|
||||
start = -peaks[col]
|
||||
if start != 0:
|
||||
holes[col] = np.count_nonzero(field[int(start) :, col] == 0)
|
||||
|
||||
return holes
|
||||
|
||||
|
||||
def get_holes_sum(
|
||||
*, holes: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the total number of holes in the given field or use pre-computed holes.
|
||||
|
||||
Args:
|
||||
holes: Array containing the number of holes in each column. If not provided, it will be computed from the field.
|
||||
field: The signal field. Required if holes is not provided.
|
||||
|
||||
Returns:
|
||||
The total number of holes in the field.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `holes` and `field` are `None`.
|
||||
"""
|
||||
if holes is None and field is None:
|
||||
raise ValueError("holes and field cannot both be None")
|
||||
elif holes is None:
|
||||
holes = get_holes(field)
|
||||
|
||||
return int(np.sum(holes))
|
||||
@ -1,63 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_peaks(field: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Find the peaks in each column of the given field.
|
||||
|
||||
Args:
|
||||
field: The signal field.
|
||||
|
||||
Returns:
|
||||
Array containing the indices of the peaks in each column.
|
||||
"""
|
||||
peaks = np.where(field == 1, field.shape[0] - np.argmax(field, axis=0), 0)
|
||||
return peaks.max(axis=0)
|
||||
|
||||
|
||||
def get_peaks_max(
|
||||
*, peaks: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get the maximum peak value from the provided peaks or compute peaks from the field.
|
||||
|
||||
Args:
|
||||
peaks: Array containing the indices of the peaks in each column. If not provided, it will be computed from the field.
|
||||
field: The signal field. Required if peaks is not provided.
|
||||
|
||||
Returns:
|
||||
The maximum peak value.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `peaks` and `field` are `None`.
|
||||
"""
|
||||
if peaks is None and field is None:
|
||||
raise ValueError("peaks and field cannot both be None")
|
||||
elif peaks is None:
|
||||
peaks = get_peaks(field)
|
||||
return int(np.max(peaks))
|
||||
|
||||
|
||||
def get_peaks_sum(
|
||||
*, peaks: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get the sum of peak values from the provided peaks or compute peaks from the field.
|
||||
|
||||
Args:
|
||||
peaks: Array containing the indices of the peaks in each column. If not provided, it will be computed from the field.
|
||||
field: The signal field. Required if peaks is not provided.
|
||||
|
||||
Returns:
|
||||
The sum of peak values.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `peaks` and `field` are `None`.
|
||||
"""
|
||||
if peaks is None and field is None:
|
||||
raise ValueError("peaks and field cannot both be None")
|
||||
elif peaks is None:
|
||||
peaks = get_peaks(field)
|
||||
return np.sum(peaks)
|
||||
@ -1,53 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .peaks import get_peaks, get_peaks_max
|
||||
|
||||
|
||||
def get_row_transition(field: np.ndarray, highest_peak: Optional[int] = None) -> int:
|
||||
"""
|
||||
Calculate the number of transitions in the rows of the given field.
|
||||
|
||||
Args:
|
||||
field: The signal field.
|
||||
highest_peak: The highest peak value. If not provided, it will be computed from the field.
|
||||
|
||||
Returns:
|
||||
The total number of transitions in the rows.
|
||||
"""
|
||||
if highest_peak is None:
|
||||
highest_peak = get_peaks_max(field=field)
|
||||
|
||||
rows_to_check = slice(int(field.shape[0] - highest_peak), field.shape[0])
|
||||
transitions = np.sum(field[rows_to_check, 1:] != field[rows_to_check, :-1])
|
||||
|
||||
return int(transitions)
|
||||
|
||||
|
||||
def get_col_transition(field: np.ndarray, peaks: Optional[np.ndarray] = None) -> int:
|
||||
"""
|
||||
Calculate the number of transitions in the columns of the given field.
|
||||
|
||||
Args:
|
||||
field: The signal field.
|
||||
peaks: Array containing the indices of the peaks in each column. If not provided, it will be computed from the field.
|
||||
|
||||
Returns:
|
||||
The total number of transitions in the columns.
|
||||
"""
|
||||
if peaks is None:
|
||||
peaks = get_peaks(field)
|
||||
|
||||
transitions_sum = 0
|
||||
|
||||
for col in range(field.shape[1]):
|
||||
if peaks[col] <= 1:
|
||||
continue
|
||||
|
||||
col_values = field[int(field.shape[0] - peaks[col]) : field.shape[0], col]
|
||||
transitions = np.sum(col_values[:-1] != col_values[1:])
|
||||
|
||||
transitions_sum += transitions
|
||||
|
||||
return transitions_sum
|
||||
@ -1,69 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .peaks import get_peaks
|
||||
|
||||
|
||||
def get_wells(
|
||||
*, peaks: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculate the well depths in each column of the given field.
|
||||
|
||||
Args:
|
||||
peaks: Array containing the indices of the peaks in each column. If not provided, it will be computed from the field.
|
||||
field: The signal field. Required if peaks is not provided.
|
||||
|
||||
Returns:
|
||||
Array containing the well depths in each column.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `peaks` and `field` are `None`.
|
||||
"""
|
||||
if peaks is None and field is None:
|
||||
raise ValueError("peaks and field cannot both be None")
|
||||
elif peaks is None:
|
||||
peaks = get_peaks(field)
|
||||
|
||||
wells = np.zeros_like(peaks)
|
||||
|
||||
first_well = peaks[1] - peaks[0]
|
||||
wells[0] = first_well if first_well > 0 else 0
|
||||
|
||||
last_well = peaks[-2] - peaks[-1]
|
||||
wells[-1] = last_well if last_well > 0 else 0
|
||||
|
||||
for idx in range(1, len(peaks) - 1):
|
||||
well_l = peaks[idx - 1] - peaks[idx]
|
||||
well_l = well_l if well_l > 0 else 0
|
||||
|
||||
well_r = peaks[idx + 1] - peaks[idx]
|
||||
well_r = well_r if well_r > 0 else 0
|
||||
|
||||
wells[idx] = well_l if well_l >= well_r else well_r
|
||||
|
||||
return wells
|
||||
|
||||
|
||||
def get_wells_max(
|
||||
*, wells: Optional[np.ndarray] = None, field: Optional[np.ndarray] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get the maximum well depth from the provided wells or compute wells from the field.
|
||||
|
||||
Args:
|
||||
wells: Array containing the well depths in each column. If not provided, it will be computed from the field.
|
||||
field: The signal field. Required if wells is not provided.
|
||||
|
||||
Returns:
|
||||
The maximum well depth.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `wells` and `field` are `None`.
|
||||
"""
|
||||
if wells is None and field is None:
|
||||
raise ValueError("wells and field cannot both be None")
|
||||
elif wells is None:
|
||||
wells = get_wells(field)
|
||||
return int(np.max(wells))
|
||||
25
src/ai/io.py
25
src/ai/io.py
@ -1,25 +0,0 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import neat
|
||||
from utils import CONFIG
|
||||
|
||||
|
||||
def get_config() -> neat.Config:
|
||||
return neat.Config(
|
||||
neat.DefaultGenome,
|
||||
neat.DefaultReproduction,
|
||||
neat.DefaultSpeciesSet,
|
||||
neat.DefaultStagnation,
|
||||
CONFIG.ai.config_path,
|
||||
)
|
||||
|
||||
|
||||
def load_genome() -> neat.DefaultGenome:
|
||||
with open(CONFIG.ai.winner_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def save_genome(genome: neat.DefaultGenome) -> None:
|
||||
with open(CONFIG.ai.winner_path, "wb") as f:
|
||||
pickle.dump(genome, f)
|
||||
@ -1,13 +0,0 @@
|
||||
from loguru import logger
|
||||
from utils import BASE_PATH, CONFIG
|
||||
|
||||
log = logger.bind(name="ai")
|
||||
|
||||
log.add(
|
||||
BASE_PATH / ".logs" / "ai.log",
|
||||
format="{time} | {level} | {message}",
|
||||
level=CONFIG.log_level.upper(),
|
||||
rotation="10 MB",
|
||||
compression="zip",
|
||||
filter=lambda record: record["extra"].get("name") == "ai",
|
||||
)
|
||||
@ -1,3 +0,0 @@
|
||||
from .calculate import calculate_fitness
|
||||
|
||||
__all__ = ["calculate_fitness"]
|
||||
@ -1,16 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def bumpiness(
|
||||
field: np.ndarray[int, np.dtype[np.uint8]],
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the bumpiness of a given signal based on peaks.
|
||||
|
||||
Args:
|
||||
field: The game field.
|
||||
|
||||
Returns:
|
||||
The bumpiness of the field.
|
||||
"""
|
||||
return int(np.sum(np.abs(np.diff(field.shape[0] - np.argmax(field, axis=0)))))
|
||||
@ -1,35 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ai.log import log
|
||||
|
||||
from .bumpiness import bumpiness
|
||||
from .height import aggregate_height
|
||||
from .holes import holes
|
||||
from .lines import complete_lines
|
||||
|
||||
|
||||
def calculate_fitness(field: np.ndarray) -> float:
|
||||
"""
|
||||
Calculate the fitness value for the given field.
|
||||
|
||||
Args:
|
||||
field: The game field.
|
||||
|
||||
Returns:
|
||||
The fitness value.
|
||||
"""
|
||||
|
||||
height_w = aggregate_height(field)
|
||||
holes_w = holes(field)
|
||||
bumpiness_w = bumpiness(field)
|
||||
lines_w = complete_lines(field)
|
||||
|
||||
fitness = (
|
||||
-0.510066 * height_w
|
||||
+ 0.760666 * lines_w
|
||||
- 0.35663 * holes_w
|
||||
- 0.184483 * bumpiness_w
|
||||
)
|
||||
return fitness
|
||||
@ -1,14 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def aggregate_height(field: np.ndarray[int, np.dtype[np.uint8]]) -> int:
|
||||
"""
|
||||
Calculates the aggregate height of the field.
|
||||
|
||||
Args:
|
||||
field: 2D array representing the game field.
|
||||
|
||||
Returns:
|
||||
The aggregate height of the field.
|
||||
"""
|
||||
return int(np.sum(field.shape[0] - np.argmax(field, axis=0)))
|
||||
@ -1,24 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def holes(
|
||||
field: np.ndarray[int, np.dtype[np.uint8]],
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the number of holes in each column of the given field.
|
||||
|
||||
Args:
|
||||
field: The signal field.
|
||||
peaks: Array containing peak indices. If not provided, it will be computed from the field.
|
||||
|
||||
Returns:
|
||||
The total number of holes in the field.
|
||||
"""
|
||||
|
||||
first_nonzero_indices = np.argmax(field != 0, axis=0)
|
||||
|
||||
mask = (field == 0) & (
|
||||
np.arange(field.shape[0])[:, np.newaxis] > first_nonzero_indices
|
||||
)
|
||||
|
||||
return int(np.sum(mask))
|
||||
@ -1,14 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def complete_lines(field: np.ndarray[int, np.dtype[np.uint8]]) -> int:
|
||||
"""
|
||||
Calculates the number of complete lines in the field.
|
||||
|
||||
Args:
|
||||
field: 2D array representing the game field.
|
||||
|
||||
Returns:
|
||||
The number of complete lines in the field.
|
||||
"""
|
||||
return int(np.sum(np.all(field, axis=1)))
|
||||
@ -1,65 +0,0 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import neat
|
||||
import pygame
|
||||
from game import Main
|
||||
from utils import BASE_PATH, CONFIG
|
||||
|
||||
from .evaluations import eval_genome
|
||||
from .io import get_config, save_genome
|
||||
from .log import log
|
||||
from .visualize import draw_net, plot_species, plot_stats
|
||||
|
||||
|
||||
def train(
|
||||
gen_count: int = CONFIG.ai.generations,
|
||||
parallel: int = CONFIG.ai.parallels,
|
||||
checkpoint_path: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Train the AI
|
||||
Args:
|
||||
gen_count: Number of generations to train (default is 200).
|
||||
threads: Number of threads to use (default is 1).
|
||||
checkpoint_path: Path to a checkpoint file to resume training from.
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
population = (
|
||||
neat.Checkpointer().restore_checkpoint(checkpoint_path)
|
||||
if checkpoint_path
|
||||
else neat.Population(config)
|
||||
)
|
||||
population.add_reporter(neat.StdOutReporter(True))
|
||||
stats = neat.StatisticsReporter()
|
||||
population.add_reporter(stats)
|
||||
population.add_reporter(
|
||||
neat.Checkpointer(
|
||||
CONFIG.ai.checkpoint.generation_interval,
|
||||
CONFIG.ai.checkpoint.time_interval,
|
||||
CONFIG.ai.checkpoint.filename_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
pe = neat.ParallelEvaluator(parallel, eval_genome)
|
||||
|
||||
winner = population.run(pe.evaluate, gen_count)
|
||||
plot_stats(
|
||||
stats,
|
||||
ylog=False,
|
||||
view=False,
|
||||
filename=CONFIG.ai.plot_path / "avg_fitness.png",
|
||||
)
|
||||
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
|
||||
draw_net(config, winner, view=False, filename=CONFIG.ai.plot_path / "network.gv")
|
||||
draw_net(
|
||||
config,
|
||||
winner,
|
||||
view=False,
|
||||
filename=CONFIG.ai.plot_path / "network-pruned.gv",
|
||||
prune_unused=True,
|
||||
)
|
||||
|
||||
log.info("Saving best genome")
|
||||
save_genome(winner)
|
||||
@ -1,141 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import neat
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_stats(
|
||||
statistics: neat.StatisticsReporter,
|
||||
ylog: bool = False,
|
||||
view: bool = False,
|
||||
filename: str | Path = "avg_fitness.svg",
|
||||
):
|
||||
"""Plots the population's average and best fitness."""
|
||||
|
||||
generation = range(len(statistics.most_fit_genomes))
|
||||
best_fitness = [c.fitness for c in statistics.most_fit_genomes]
|
||||
avg_fitness = np.array(statistics.get_fitness_mean())
|
||||
stdev_fitness = np.array(statistics.get_fitness_stdev())
|
||||
|
||||
plt.plot(generation, avg_fitness, "b-", label="average")
|
||||
plt.plot(generation, avg_fitness - stdev_fitness, "g-.", label="-1 sd")
|
||||
plt.plot(generation, avg_fitness + stdev_fitness, "g-.", label="+1 sd")
|
||||
plt.plot(generation, best_fitness, "r-", label="best")
|
||||
|
||||
plt.title("Population's average and best fitness")
|
||||
plt.xlabel("Generations")
|
||||
plt.ylabel("Fitness")
|
||||
plt.grid()
|
||||
plt.legend(loc="best")
|
||||
if ylog:
|
||||
plt.gca().set_yscale("symlog")
|
||||
|
||||
plt.savefig(str(filename))
|
||||
if view:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_species(
|
||||
statistics: neat.StatisticsReporter,
|
||||
view: bool = False,
|
||||
filename: str | Path = "speciation.svg",
|
||||
):
|
||||
"""Visualizes speciation throughout evolution."""
|
||||
|
||||
species_sizes = statistics.get_species_sizes()
|
||||
num_generations = len(species_sizes)
|
||||
curves = np.array(species_sizes).T
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.stackplot(range(num_generations), *curves)
|
||||
|
||||
plt.title("Speciation")
|
||||
plt.ylabel("Size per Species")
|
||||
plt.xlabel("Generations")
|
||||
|
||||
plt.savefig(str(filename))
|
||||
|
||||
if view:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def draw_net(
|
||||
config: neat.Config,
|
||||
genome: neat.DefaultGenome,
|
||||
view: bool = False,
|
||||
filename: str | Path = None,
|
||||
node_names: dict = None,
|
||||
show_disabled: bool = True,
|
||||
prune_unused: bool = False,
|
||||
node_colors: dict = None,
|
||||
fmt: str = "svg",
|
||||
):
|
||||
"""Receives a genome and draws a neural network with arbitrary topology."""
|
||||
|
||||
# If requested, use a copy of the genome which omits all components that won't affect the output.
|
||||
if prune_unused:
|
||||
if show_disabled:
|
||||
warnings.warn("show_disabled has no effect when prune_unused is True")
|
||||
|
||||
genome = genome.get_pruned_copy(config.genome_config)
|
||||
|
||||
if node_names is None:
|
||||
node_names = {}
|
||||
|
||||
assert type(node_names) is dict
|
||||
|
||||
if node_colors is None:
|
||||
node_colors = {}
|
||||
|
||||
assert type(node_colors) is dict
|
||||
|
||||
node_attrs = {"shape": "circle", "fontsize": "9", "height": "0.2", "width": "0.2"}
|
||||
|
||||
dot = graphviz.Digraph(format=fmt, node_attr=node_attrs)
|
||||
|
||||
inputs = set()
|
||||
for k in config.genome_config.input_keys:
|
||||
inputs.add(k)
|
||||
name = node_names.get(k, str(k))
|
||||
input_attrs = {
|
||||
"style": "filled",
|
||||
"shape": "box",
|
||||
"fillcolor": node_colors.get(k, "lightgray"),
|
||||
}
|
||||
dot.node(name, _attributes=input_attrs)
|
||||
|
||||
outputs = set()
|
||||
for k in config.genome_config.output_keys:
|
||||
outputs.add(k)
|
||||
name = node_names.get(k, str(k))
|
||||
node_attrs = {"style": "filled", "fillcolor": node_colors.get(k, "lightblue")}
|
||||
|
||||
dot.node(name, _attributes=node_attrs)
|
||||
|
||||
for n in genome.nodes.keys():
|
||||
if n in inputs or n in outputs:
|
||||
continue
|
||||
|
||||
attrs = {"style": "filled", "fillcolor": node_colors.get(n, "white")}
|
||||
dot.node(str(n), _attributes=attrs)
|
||||
|
||||
for cg in genome.connections.values():
|
||||
if cg.enabled or show_disabled:
|
||||
input, output = cg.key
|
||||
a = node_names.get(input, str(input))
|
||||
b = node_names.get(output, str(output))
|
||||
style = "solid" if cg.enabled else "dotted"
|
||||
color = "green" if cg.weight > 0 else "red"
|
||||
width = str(0.1 + abs(cg.weight / 5.0))
|
||||
dot.edge(
|
||||
a, b, _attributes={"style": style, "color": color, "penwidth": width}
|
||||
)
|
||||
|
||||
dot.render(filename, view=view)
|
||||
|
||||
return dot
|
||||
@ -1,70 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from ai.fitness.bumpiness import get_bumpiness
|
||||
from ai.fitness.holes import holes
|
||||
from ai.fitness.peaks import get_peaks_sum
|
||||
from ai.fitness.transitions import (
|
||||
get_col_transition,
|
||||
get_row_transition,
|
||||
)
|
||||
from ai.fitness.wells import get_wells
|
||||
|
||||
|
||||
class TestFitness(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.fields: tuple[np.ndarray] = (
|
||||
np.array(
|
||||
[
|
||||
[0, 1, 0, 0, 1],
|
||||
[1, 0, 0, 1, 0],
|
||||
[0, 1, 1, 0, 0],
|
||||
]
|
||||
),
|
||||
np.zeros((3, 5)),
|
||||
np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
def test_get_peaks_sum(self) -> None:
|
||||
answers: tuple[int] = (11, 0, 2)
|
||||
for field, answer in zip(self.fields, answers):
|
||||
self.assertEqual(get_peaks_sum(field=field), answer)
|
||||
|
||||
def test_get_row_transistions(self) -> None:
|
||||
answers = (8, 0, 2)
|
||||
for field, answer in zip(self.fields, answers):
|
||||
self.assertEqual(get_row_transition(field), answer)
|
||||
|
||||
def test_get_col_transistions(self) -> None:
|
||||
answers = (5, 0, 1)
|
||||
for field, answer in zip(self.fields, answers):
|
||||
self.assertEqual(get_col_transition(field), answer)
|
||||
|
||||
def test_get_bumpiness(self):
|
||||
answers = (5, 0, 4)
|
||||
for field, answer in zip(self.fields, answers):
|
||||
self.assertEqual(get_bumpiness(field=field), answer)
|
||||
|
||||
def test_get_holes(self) -> None:
|
||||
answers = (
|
||||
np.array([1, 1, 0, 1, 2]),
|
||||
np.array([0, 0, 0, 0, 0]),
|
||||
np.array([0, 1, 0, 0, 0]),
|
||||
)
|
||||
for field, answer in zip(self.fields, answers):
|
||||
self.assertTrue(np.array_equal(holes(field), answer))
|
||||
|
||||
def test_get_wells(self) -> None:
|
||||
answers = (
|
||||
np.array([1, 0, 2, 1, 0]),
|
||||
np.array([0, 0, 0, 0, 0]),
|
||||
np.array([2, 0, 2, 0, 0]),
|
||||
)
|
||||
for field, answer in zip(self.fields, answers):
|
||||
self.assertTrue(np.array_equal(get_wells(field=field), answer))
|
||||
@ -1,33 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from ai.moves.bumpiness import bumpiness
|
||||
from ai.moves.height import aggregate_height
|
||||
from ai.moves.holes import holes
|
||||
from ai.moves.lines import complete_lines
|
||||
|
||||
|
||||
class TestFitness(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.field = np.array(
|
||||
[
|
||||
[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 1, 1, 1, 0, 0, 1],
|
||||
[0, 1, 1, 0, 1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 0, 1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
]
|
||||
)
|
||||
|
||||
def test_aggregate_height(self) -> None:
|
||||
self.assertEqual(aggregate_height(self.field), 48)
|
||||
|
||||
def test_complete_lines(self) -> None:
|
||||
self.assertEqual(complete_lines(self.field), 2)
|
||||
|
||||
def test_holes(self) -> None:
|
||||
self.assertEqual(holes(self.field), 2)
|
||||
|
||||
def test_bumpiness(self) -> None:
|
||||
self.assertEqual(bumpiness(self.field), 6)
|
||||
Loading…
Reference in New Issue
Block a user