feat(ai): add draw_net

This commit is contained in:
Kristofers Solo 2024-01-05 20:16:52 +02:00
parent 80b1c7518e
commit 599f456c46
2 changed files with 80 additions and 34 deletions

View File

@ -9,7 +9,7 @@ from utils import BASE_PATH, CONFIG
from .evaluations import eval_genome from .evaluations import eval_genome
from .io import get_config, save_genome from .io import get_config, save_genome
from .log import log from .log import log
from .visualize import plot_progress, plot_species, plot_stats from .visualize import draw_net, plot_species, plot_stats
def train( def train(
@ -52,6 +52,14 @@ def train(
filename=CONFIG.ai.plot_path / "avg_fitness.png", filename=CONFIG.ai.plot_path / "avg_fitness.png",
) )
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png") plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
draw_net(config, winner, view=False, filename=CONFIG.ai.plot_path / "network.gv")
draw_net(
config,
winner,
view=False,
filename=CONFIG.ai.plot_path / "network-pruned.gv",
prune_unused=True,
)
log.info("Saving best genome") log.info("Saving best genome")
save_genome(winner) save_genome(winner)

View File

@ -4,8 +4,6 @@ import matplotlib.pyplot as plt
import neat import neat
import numpy as np import numpy as np
from .log import log
def plot_stats( def plot_stats(
statistics: neat.StatisticsReporter, statistics: neat.StatisticsReporter,
@ -14,11 +12,6 @@ def plot_stats(
filename: str | Path = "avg_fitness.svg", filename: str | Path = "avg_fitness.svg",
): ):
"""Plots the population's average and best fitness.""" """Plots the population's average and best fitness."""
if plt is None:
log.warning(
"This display is not available due to a missing optional dependency (matplotlib)"
)
return
generation = range(len(statistics.most_fit_genomes)) generation = range(len(statistics.most_fit_genomes))
best_fitness = [c.fitness for c in statistics.most_fit_genomes] best_fitness = [c.fitness for c in statistics.most_fit_genomes]
@ -51,11 +44,6 @@ def plot_species(
filename: str | Path = "speciation.svg", filename: str | Path = "speciation.svg",
): ):
"""Visualizes speciation throughout evolution.""" """Visualizes speciation throughout evolution."""
if plt is None:
log.warning(
"This display is not available due to a missing optional dependency (matplotlib)"
)
return
species_sizes = statistics.get_species_sizes() species_sizes = statistics.get_species_sizes()
num_generations = len(species_sizes) num_generations = len(species_sizes)
@ -76,28 +64,78 @@ def plot_species(
plt.close() plt.close()
def plot_progress( def draw_net(
generations: list[int], config: neat.Config,
mean_fitness: list[int], genome: neat.DefaultGenome,
max_fitness: list[int],
view: bool = False, view: bool = False,
filename: str | Path = "progress.svg", filename: str | Path = None,
node_names: dict = None,
show_disabled: bool = True,
prune_unused: bool = False,
node_colors: dict = None,
fmt: str = "svg",
): ):
if plt is None: """Receives a genome and draws a neural network with arbitrary topology."""
log.warning(
"This display is not available due to a missing optional dependency (matplotlib)"
)
return
plt.plot(generations, mean_fitness, label="Mean Fitness")
plt.plot(generations, max_fitness, label="Max Fitness")
plt.xlabel("Generations")
plt.ylabel("Fitness")
plt.title("NEAT Algorithm Progress")
plt.legend()
plt.grid(True)
plt.savefig(str(filename))
if view: # If requested, use a copy of the genome which omits all components that won't affect the output.
plt.show() if prune_unused:
if show_disabled:
warnings.warn("show_disabled has no effect when prune_unused is True")
plt.close() genome = genome.get_pruned_copy(config.genome_config)
if node_names is None:
node_names = {}
assert type(node_names) is dict
if node_colors is None:
node_colors = {}
assert type(node_colors) is dict
node_attrs = {"shape": "circle", "fontsize": "9", "height": "0.2", "width": "0.2"}
dot = graphviz.Digraph(format=fmt, node_attr=node_attrs)
inputs = set()
for k in config.genome_config.input_keys:
inputs.add(k)
name = node_names.get(k, str(k))
input_attrs = {
"style": "filled",
"shape": "box",
"fillcolor": node_colors.get(k, "lightgray"),
}
dot.node(name, _attributes=input_attrs)
outputs = set()
for k in config.genome_config.output_keys:
outputs.add(k)
name = node_names.get(k, str(k))
node_attrs = {"style": "filled", "fillcolor": node_colors.get(k, "lightblue")}
dot.node(name, _attributes=node_attrs)
for n in genome.nodes.keys():
if n in inputs or n in outputs:
continue
attrs = {"style": "filled", "fillcolor": node_colors.get(n, "white")}
dot.node(str(n), _attributes=attrs)
for cg in genome.connections.values():
if cg.enabled or show_disabled:
input, output = cg.key
a = node_names.get(input, str(input))
b = node_names.get(output, str(output))
style = "solid" if cg.enabled else "dotted"
color = "green" if cg.weight > 0 else "red"
width = str(0.1 + abs(cg.weight / 5.0))
dot.edge(
a, b, _attributes={"style": style, "color": color, "penwidth": width}
)
dot.render(filename, view=view)
return dot