mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): add draw_net
This commit is contained in:
parent
80b1c7518e
commit
599f456c46
@ -9,7 +9,7 @@ from utils import BASE_PATH, CONFIG
|
||||
from .evaluations import eval_genome
|
||||
from .io import get_config, save_genome
|
||||
from .log import log
|
||||
from .visualize import plot_progress, plot_species, plot_stats
|
||||
from .visualize import draw_net, plot_species, plot_stats
|
||||
|
||||
|
||||
def train(
|
||||
@ -52,6 +52,14 @@ def train(
|
||||
filename=CONFIG.ai.plot_path / "avg_fitness.png",
|
||||
)
|
||||
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
|
||||
draw_net(config, winner, view=False, filename=CONFIG.ai.plot_path / "network.gv")
|
||||
draw_net(
|
||||
config,
|
||||
winner,
|
||||
view=False,
|
||||
filename=CONFIG.ai.plot_path / "network-pruned.gv",
|
||||
prune_unused=True,
|
||||
)
|
||||
|
||||
log.info("Saving best genome")
|
||||
save_genome(winner)
|
||||
|
||||
@ -4,8 +4,6 @@ import matplotlib.pyplot as plt
|
||||
import neat
|
||||
import numpy as np
|
||||
|
||||
from .log import log
|
||||
|
||||
|
||||
def plot_stats(
|
||||
statistics: neat.StatisticsReporter,
|
||||
@ -14,11 +12,6 @@ def plot_stats(
|
||||
filename: str | Path = "avg_fitness.svg",
|
||||
):
|
||||
"""Plots the population's average and best fitness."""
|
||||
if plt is None:
|
||||
log.warning(
|
||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
||||
)
|
||||
return
|
||||
|
||||
generation = range(len(statistics.most_fit_genomes))
|
||||
best_fitness = [c.fitness for c in statistics.most_fit_genomes]
|
||||
@ -51,11 +44,6 @@ def plot_species(
|
||||
filename: str | Path = "speciation.svg",
|
||||
):
|
||||
"""Visualizes speciation throughout evolution."""
|
||||
if plt is None:
|
||||
log.warning(
|
||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
||||
)
|
||||
return
|
||||
|
||||
species_sizes = statistics.get_species_sizes()
|
||||
num_generations = len(species_sizes)
|
||||
@ -76,28 +64,78 @@ def plot_species(
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_progress(
|
||||
generations: list[int],
|
||||
mean_fitness: list[int],
|
||||
max_fitness: list[int],
|
||||
def draw_net(
|
||||
config: neat.Config,
|
||||
genome: neat.DefaultGenome,
|
||||
view: bool = False,
|
||||
filename: str | Path = "progress.svg",
|
||||
filename: str | Path = None,
|
||||
node_names: dict = None,
|
||||
show_disabled: bool = True,
|
||||
prune_unused: bool = False,
|
||||
node_colors: dict = None,
|
||||
fmt: str = "svg",
|
||||
):
|
||||
if plt is None:
|
||||
log.warning(
|
||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
||||
)
|
||||
return
|
||||
plt.plot(generations, mean_fitness, label="Mean Fitness")
|
||||
plt.plot(generations, max_fitness, label="Max Fitness")
|
||||
plt.xlabel("Generations")
|
||||
plt.ylabel("Fitness")
|
||||
plt.title("NEAT Algorithm Progress")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(str(filename))
|
||||
"""Receives a genome and draws a neural network with arbitrary topology."""
|
||||
|
||||
if view:
|
||||
plt.show()
|
||||
# If requested, use a copy of the genome which omits all components that won't affect the output.
|
||||
if prune_unused:
|
||||
if show_disabled:
|
||||
warnings.warn("show_disabled has no effect when prune_unused is True")
|
||||
|
||||
plt.close()
|
||||
genome = genome.get_pruned_copy(config.genome_config)
|
||||
|
||||
if node_names is None:
|
||||
node_names = {}
|
||||
|
||||
assert type(node_names) is dict
|
||||
|
||||
if node_colors is None:
|
||||
node_colors = {}
|
||||
|
||||
assert type(node_colors) is dict
|
||||
|
||||
node_attrs = {"shape": "circle", "fontsize": "9", "height": "0.2", "width": "0.2"}
|
||||
|
||||
dot = graphviz.Digraph(format=fmt, node_attr=node_attrs)
|
||||
|
||||
inputs = set()
|
||||
for k in config.genome_config.input_keys:
|
||||
inputs.add(k)
|
||||
name = node_names.get(k, str(k))
|
||||
input_attrs = {
|
||||
"style": "filled",
|
||||
"shape": "box",
|
||||
"fillcolor": node_colors.get(k, "lightgray"),
|
||||
}
|
||||
dot.node(name, _attributes=input_attrs)
|
||||
|
||||
outputs = set()
|
||||
for k in config.genome_config.output_keys:
|
||||
outputs.add(k)
|
||||
name = node_names.get(k, str(k))
|
||||
node_attrs = {"style": "filled", "fillcolor": node_colors.get(k, "lightblue")}
|
||||
|
||||
dot.node(name, _attributes=node_attrs)
|
||||
|
||||
for n in genome.nodes.keys():
|
||||
if n in inputs or n in outputs:
|
||||
continue
|
||||
|
||||
attrs = {"style": "filled", "fillcolor": node_colors.get(n, "white")}
|
||||
dot.node(str(n), _attributes=attrs)
|
||||
|
||||
for cg in genome.connections.values():
|
||||
if cg.enabled or show_disabled:
|
||||
input, output = cg.key
|
||||
a = node_names.get(input, str(input))
|
||||
b = node_names.get(output, str(output))
|
||||
style = "solid" if cg.enabled else "dotted"
|
||||
color = "green" if cg.weight > 0 else "red"
|
||||
width = str(0.1 + abs(cg.weight / 5.0))
|
||||
dot.edge(
|
||||
a, b, _attributes={"style": style, "color": color, "penwidth": width}
|
||||
)
|
||||
|
||||
dot.render(filename, view=view)
|
||||
|
||||
return dot
|
||||
|
||||
Loading…
Reference in New Issue
Block a user