mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): add draw_net
This commit is contained in:
parent
80b1c7518e
commit
599f456c46
@ -9,7 +9,7 @@ from utils import BASE_PATH, CONFIG
|
|||||||
from .evaluations import eval_genome
|
from .evaluations import eval_genome
|
||||||
from .io import get_config, save_genome
|
from .io import get_config, save_genome
|
||||||
from .log import log
|
from .log import log
|
||||||
from .visualize import plot_progress, plot_species, plot_stats
|
from .visualize import draw_net, plot_species, plot_stats
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -52,6 +52,14 @@ def train(
|
|||||||
filename=CONFIG.ai.plot_path / "avg_fitness.png",
|
filename=CONFIG.ai.plot_path / "avg_fitness.png",
|
||||||
)
|
)
|
||||||
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
|
plot_species(stats, view=False, filename=CONFIG.ai.plot_path / "speciation.png")
|
||||||
|
draw_net(config, winner, view=False, filename=CONFIG.ai.plot_path / "network.gv")
|
||||||
|
draw_net(
|
||||||
|
config,
|
||||||
|
winner,
|
||||||
|
view=False,
|
||||||
|
filename=CONFIG.ai.plot_path / "network-pruned.gv",
|
||||||
|
prune_unused=True,
|
||||||
|
)
|
||||||
|
|
||||||
log.info("Saving best genome")
|
log.info("Saving best genome")
|
||||||
save_genome(winner)
|
save_genome(winner)
|
||||||
|
|||||||
@ -4,8 +4,6 @@ import matplotlib.pyplot as plt
|
|||||||
import neat
|
import neat
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .log import log
|
|
||||||
|
|
||||||
|
|
||||||
def plot_stats(
|
def plot_stats(
|
||||||
statistics: neat.StatisticsReporter,
|
statistics: neat.StatisticsReporter,
|
||||||
@ -14,11 +12,6 @@ def plot_stats(
|
|||||||
filename: str | Path = "avg_fitness.svg",
|
filename: str | Path = "avg_fitness.svg",
|
||||||
):
|
):
|
||||||
"""Plots the population's average and best fitness."""
|
"""Plots the population's average and best fitness."""
|
||||||
if plt is None:
|
|
||||||
log.warning(
|
|
||||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
generation = range(len(statistics.most_fit_genomes))
|
generation = range(len(statistics.most_fit_genomes))
|
||||||
best_fitness = [c.fitness for c in statistics.most_fit_genomes]
|
best_fitness = [c.fitness for c in statistics.most_fit_genomes]
|
||||||
@ -51,11 +44,6 @@ def plot_species(
|
|||||||
filename: str | Path = "speciation.svg",
|
filename: str | Path = "speciation.svg",
|
||||||
):
|
):
|
||||||
"""Visualizes speciation throughout evolution."""
|
"""Visualizes speciation throughout evolution."""
|
||||||
if plt is None:
|
|
||||||
log.warning(
|
|
||||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
species_sizes = statistics.get_species_sizes()
|
species_sizes = statistics.get_species_sizes()
|
||||||
num_generations = len(species_sizes)
|
num_generations = len(species_sizes)
|
||||||
@ -76,28 +64,78 @@ def plot_species(
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def plot_progress(
|
def draw_net(
|
||||||
generations: list[int],
|
config: neat.Config,
|
||||||
mean_fitness: list[int],
|
genome: neat.DefaultGenome,
|
||||||
max_fitness: list[int],
|
|
||||||
view: bool = False,
|
view: bool = False,
|
||||||
filename: str | Path = "progress.svg",
|
filename: str | Path = None,
|
||||||
|
node_names: dict = None,
|
||||||
|
show_disabled: bool = True,
|
||||||
|
prune_unused: bool = False,
|
||||||
|
node_colors: dict = None,
|
||||||
|
fmt: str = "svg",
|
||||||
):
|
):
|
||||||
if plt is None:
|
"""Receives a genome and draws a neural network with arbitrary topology."""
|
||||||
log.warning(
|
|
||||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
plt.plot(generations, mean_fitness, label="Mean Fitness")
|
|
||||||
plt.plot(generations, max_fitness, label="Max Fitness")
|
|
||||||
plt.xlabel("Generations")
|
|
||||||
plt.ylabel("Fitness")
|
|
||||||
plt.title("NEAT Algorithm Progress")
|
|
||||||
plt.legend()
|
|
||||||
plt.grid(True)
|
|
||||||
plt.savefig(str(filename))
|
|
||||||
|
|
||||||
if view:
|
# If requested, use a copy of the genome which omits all components that won't affect the output.
|
||||||
plt.show()
|
if prune_unused:
|
||||||
|
if show_disabled:
|
||||||
|
warnings.warn("show_disabled has no effect when prune_unused is True")
|
||||||
|
|
||||||
plt.close()
|
genome = genome.get_pruned_copy(config.genome_config)
|
||||||
|
|
||||||
|
if node_names is None:
|
||||||
|
node_names = {}
|
||||||
|
|
||||||
|
assert type(node_names) is dict
|
||||||
|
|
||||||
|
if node_colors is None:
|
||||||
|
node_colors = {}
|
||||||
|
|
||||||
|
assert type(node_colors) is dict
|
||||||
|
|
||||||
|
node_attrs = {"shape": "circle", "fontsize": "9", "height": "0.2", "width": "0.2"}
|
||||||
|
|
||||||
|
dot = graphviz.Digraph(format=fmt, node_attr=node_attrs)
|
||||||
|
|
||||||
|
inputs = set()
|
||||||
|
for k in config.genome_config.input_keys:
|
||||||
|
inputs.add(k)
|
||||||
|
name = node_names.get(k, str(k))
|
||||||
|
input_attrs = {
|
||||||
|
"style": "filled",
|
||||||
|
"shape": "box",
|
||||||
|
"fillcolor": node_colors.get(k, "lightgray"),
|
||||||
|
}
|
||||||
|
dot.node(name, _attributes=input_attrs)
|
||||||
|
|
||||||
|
outputs = set()
|
||||||
|
for k in config.genome_config.output_keys:
|
||||||
|
outputs.add(k)
|
||||||
|
name = node_names.get(k, str(k))
|
||||||
|
node_attrs = {"style": "filled", "fillcolor": node_colors.get(k, "lightblue")}
|
||||||
|
|
||||||
|
dot.node(name, _attributes=node_attrs)
|
||||||
|
|
||||||
|
for n in genome.nodes.keys():
|
||||||
|
if n in inputs or n in outputs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
attrs = {"style": "filled", "fillcolor": node_colors.get(n, "white")}
|
||||||
|
dot.node(str(n), _attributes=attrs)
|
||||||
|
|
||||||
|
for cg in genome.connections.values():
|
||||||
|
if cg.enabled or show_disabled:
|
||||||
|
input, output = cg.key
|
||||||
|
a = node_names.get(input, str(input))
|
||||||
|
b = node_names.get(output, str(output))
|
||||||
|
style = "solid" if cg.enabled else "dotted"
|
||||||
|
color = "green" if cg.weight > 0 else "red"
|
||||||
|
width = str(0.1 + abs(cg.weight / 5.0))
|
||||||
|
dot.edge(
|
||||||
|
a, b, _attributes={"style": style, "color": color, "penwidth": width}
|
||||||
|
)
|
||||||
|
|
||||||
|
dot.render(filename, view=view)
|
||||||
|
|
||||||
|
return dot
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user