mirror of
https://github.com/kristoferssolo/Tetris.git
synced 2025-10-21 20:00:35 +00:00
feat(ai): add plot
This commit is contained in:
parent
64f14d178f
commit
845f2bd024
103
src/ai/visualize.py
Normal file
103
src/ai/visualize.py
Normal file
@ -0,0 +1,103 @@
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import neat
|
||||
import numpy as np
|
||||
|
||||
from .log import log
|
||||
|
||||
|
||||
def plot_stats(
|
||||
statistics: neat.StatisticsReporter,
|
||||
ylog: bool = False,
|
||||
view: bool = False,
|
||||
filename: str | Path = "avg_fitness.svg",
|
||||
):
|
||||
"""Plots the population's average and best fitness."""
|
||||
if plt is None:
|
||||
log.warning(
|
||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
||||
)
|
||||
return
|
||||
|
||||
generation = range(len(statistics.most_fit_genomes))
|
||||
best_fitness = [c.fitness for c in statistics.most_fit_genomes]
|
||||
avg_fitness = np.array(statistics.get_fitness_mean())
|
||||
stdev_fitness = np.array(statistics.get_fitness_stdev())
|
||||
|
||||
plt.plot(generation, avg_fitness, "b-", label="average")
|
||||
plt.plot(generation, avg_fitness - stdev_fitness, "g-.", label="-1 sd")
|
||||
plt.plot(generation, avg_fitness + stdev_fitness, "g-.", label="+1 sd")
|
||||
plt.plot(generation, best_fitness, "r-", label="best")
|
||||
|
||||
plt.title("Population's average and best fitness")
|
||||
plt.xlabel("Generations")
|
||||
plt.ylabel("Fitness")
|
||||
plt.grid()
|
||||
plt.legend(loc="best")
|
||||
if ylog:
|
||||
plt.gca().set_yscale("symlog")
|
||||
|
||||
plt.savefig(str(filename))
|
||||
if view:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_species(
|
||||
statistics: neat.StatisticsReporter,
|
||||
view: bool = False,
|
||||
filename: str | Path = "speciation.svg",
|
||||
):
|
||||
"""Visualizes speciation throughout evolution."""
|
||||
if plt is None:
|
||||
log.warning(
|
||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
||||
)
|
||||
return
|
||||
|
||||
species_sizes = statistics.get_species_sizes()
|
||||
num_generations = len(species_sizes)
|
||||
curves = np.array(species_sizes).T
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.stackplot(range(num_generations), *curves)
|
||||
|
||||
plt.title("Speciation")
|
||||
plt.ylabel("Size per Species")
|
||||
plt.xlabel("Generations")
|
||||
|
||||
plt.savefig(str(filename))
|
||||
|
||||
if view:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_progress(
|
||||
generations: list[int],
|
||||
mean_fitness: list[int],
|
||||
max_fitness: list[int],
|
||||
view: bool = False,
|
||||
filename: str | Path = "progress.svg",
|
||||
):
|
||||
if plt is None:
|
||||
log.warning(
|
||||
"This display is not available due to a missing optional dependency (matplotlib)"
|
||||
)
|
||||
return
|
||||
plt.plot(generations, mean_fitness, label="Mean Fitness")
|
||||
plt.plot(generations, max_fitness, label="Max Fitness")
|
||||
plt.xlabel("Generations")
|
||||
plt.ylabel("Fitness")
|
||||
plt.title("NEAT Algorithm Progress")
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(str(filename))
|
||||
|
||||
if view:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
Loading…
Reference in New Issue
Block a user