From 599f456c4617cd1b9bfb11a6be9b0c5227d1d393 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Fri, 5 Jan 2024 20:16:52 +0200 Subject: [PATCH] feat(ai): add `draw_net` --- src/ai/training.py | 10 ++++- src/ai/visualize.py | 104 ++++++++++++++++++++++++++++++-------------- 2 files changed, 80 insertions(+), 34 deletions(-) diff --git a/src/ai/training.py b/src/ai/training.py index 7231696..e656c2e 100644 --- a/src/ai/training.py +++ b/src/ai/training.py @@ -9,7 +9,7 @@ from utils import BASE_PATH, CONFIG from .evaluations import eval_genome from .io import get_config, save_genome from .log import log -from .visualize import plot_progress, plot_species, plot_stats +from .visualize import draw_net, plot_species, plot_stats def train( @@ -52,6 +52,14 @@ def train( filename=CONFIG.ai.plot_path / "avg_fitness.png", ) plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png") + draw_net(config, winner, view=False, filename=CONFIG.ai.plot_path / "network.gv") + draw_net( + config, + winner, + view=False, + filename=CONFIG.ai.plot_path / "network-pruned.gv", + prune_unused=True, + ) log.info("Saving best genome") save_genome(winner) diff --git a/src/ai/visualize.py b/src/ai/visualize.py index df44f2a..c2e1cba 100644 --- a/src/ai/visualize.py +++ b/src/ai/visualize.py @@ -4,8 +4,6 @@ import matplotlib.pyplot as plt import neat import numpy as np -from .log import log - def plot_stats( statistics: neat.StatisticsReporter, @@ -14,11 +12,6 @@ def plot_stats( filename: str | Path = "avg_fitness.svg", ): """Plots the population's average and best fitness.""" - if plt is None: - log.warning( - "This display is not available due to a missing optional dependency (matplotlib)" - ) - return generation = range(len(statistics.most_fit_genomes)) best_fitness = [c.fitness for c in statistics.most_fit_genomes] @@ -51,11 +44,6 @@ def plot_species( filename: str | Path = "speciation.svg", ): """Visualizes speciation throughout evolution.""" - if plt is None: - log.warning( - "This display is not available due to a missing optional dependency (matplotlib)" - ) - return species_sizes = statistics.get_species_sizes() num_generations = len(species_sizes) @@ -76,28 +64,78 @@ def plot_species( plt.close() -def plot_progress( - generations: list[int], - mean_fitness: list[int], - max_fitness: list[int], +def draw_net( + config: neat.Config, + genome: neat.DefaultGenome, view: bool = False, - filename: str | Path = "progress.svg", + filename: str | Path = None, + node_names: dict = None, + show_disabled: bool = True, + prune_unused: bool = False, + node_colors: dict = None, + fmt: str = "svg", ): - if plt is None: - log.warning( - "This display is not available due to a missing optional dependency (matplotlib)" - ) - return - plt.plot(generations, mean_fitness, label="Mean Fitness") - plt.plot(generations, max_fitness, label="Max Fitness") - plt.xlabel("Generations") - plt.ylabel("Fitness") - plt.title("NEAT Algorithm Progress") - plt.legend() - plt.grid(True) - plt.savefig(str(filename)) + """Receives a genome and draws a neural network with arbitrary topology.""" - if view: - plt.show() + # If requested, use a copy of the genome which omits all components that won't affect the output. + if prune_unused: + if show_disabled: + warnings.warn("show_disabled has no effect when prune_unused is True") - plt.close() + genome = genome.get_pruned_copy(config.genome_config) + + if node_names is None: + node_names = {} + + assert type(node_names) is dict + + if node_colors is None: + node_colors = {} + + assert type(node_colors) is dict + + node_attrs = {"shape": "circle", "fontsize": "9", "height": "0.2", "width": "0.2"} + + dot = graphviz.Digraph(format=fmt, node_attr=node_attrs) + + inputs = set() + for k in config.genome_config.input_keys: + inputs.add(k) + name = node_names.get(k, str(k)) + input_attrs = { + "style": "filled", + "shape": "box", + "fillcolor": node_colors.get(k, "lightgray"), + } + dot.node(name, _attributes=input_attrs) + + outputs = set() + for k in config.genome_config.output_keys: + outputs.add(k) + name = node_names.get(k, str(k)) + node_attrs = {"style": "filled", "fillcolor": node_colors.get(k, "lightblue")} + + dot.node(name, _attributes=node_attrs) + + for n in genome.nodes.keys(): + if n in inputs or n in outputs: + continue + + attrs = {"style": "filled", "fillcolor": node_colors.get(n, "white")} + dot.node(str(n), _attributes=attrs) + + for cg in genome.connections.values(): + if cg.enabled or show_disabled: + input, output = cg.key + a = node_names.get(input, str(input)) + b = node_names.get(output, str(output)) + style = "solid" if cg.enabled else "dotted" + color = "green" if cg.weight > 0 else "red" + width = str(0.1 + abs(cg.weight / 5.0)) + dot.edge( + a, b, _attributes={"style": style, "color": color, "penwidth": width} + ) + + dot.render(filename, view=view) + + return dot