From 73548ed8f43d81f37a4f52768425d662490e2503 Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Wed, 3 Jan 2024 04:15:17 +0200 Subject: [PATCH] refactor(AI): separate into files --- best_genome | Bin 0 -> 2923 bytes best_genome.pkl | Bin 0 -> 2923 bytes main.py | 2 +- src/ai/__init__.py | 5 ++- src/ai/config.py | 13 +++++++ src/ai/evaluation.py | 46 +++++++++++++++++++++++ src/ai/io.py | 15 ++++++++ src/ai/train.py | 73 ------------------------------------ src/ai/training.py | 22 +++++++++++ src/py2048/objects/board.py | 5 +++ 10 files changed, 105 insertions(+), 76 deletions(-) create mode 100644 best_genome create mode 100644 best_genome.pkl create mode 100644 src/ai/config.py create mode 100644 src/ai/evaluation.py create mode 100644 src/ai/io.py delete mode 100644 src/ai/train.py create mode 100644 src/ai/training.py diff --git a/best_genome b/best_genome new file mode 100644 index 0000000000000000000000000000000000000000..90c5458b05b3a63b02d179870de9621c87b36160 GIT binary patch literal 2923 zcmZwJe^AqP9LI5BDmH#~aBz@FNtvKYM^3o9_)$P4GayZ?lT!wK*>{a?9AhXK2paOE zzHp{r(&W_2_~C_RT$r6_i=2|8CU9_wb$}p}0?tho0;KQz`~B*D-hJgS7!O{b=eXVW zDV|rI>f*qEF)n&al^dd_^adTpGH$U{rplC+8!!C4mkr**7P3K%OiSgnNA4z%-Uz=)nB%^b4y7z3f3!mxS;kwOiFvHKL5h7eN6)7hG7Gj+8a18E#z)O1#g;vn z-^4gmdR1B$mBE&q(u85hQkHfR61cy4ns$$m8u+xQZei$l+T$^DLdXk{?VQ;+`1Bb9 z$q@1ZWba(%?q547poik~Add3`^mk2mX<1?*p^JU)tRoO^dKS>_BWoiq z!*&AUra%t;CfhMQS8;KtfM%slOJV6=(AsJ5tIG+5n}Ps!6fAe#6S9mzxM>B4qE2LP z+j9TvUIO8!5LlWS4{PvnKQ5pdX;UaHjsCnXvC@`IAl#&YrI9n<5tnn~351)%0qs=Q zWEsM45(qcF2FFe1EQvWAI4huOX;TEC>v^xPKIUd25N?Ww;~rKWRW(`PBM@$ifu&|j zF{m}@2!xy7f~DgDn$lm+6$$8HY10}&H?mR=u#p=GgqzmFaW_)@OWtsNfk3z^36Ra+ z)Ds}DClGF02dMwn{=sDr3j{PJZQ1}$lZBqISLi4L;ieA&%}l3Qv}4~A2sdql<3`_( zn?BsLo zibj$qCSwSMn?8c$Y(9sUOcc)u=z+9J$D>7G3{)l@A`ouMhT|qTdS;hcRuBj`83EaP zUQ8J~w}(KuDG!djUB0F($!n8(nELSD1j0=Pa9q3ho3~i4_=dqR;ie*3 zx~q(|D>j;mCETw#RB#}aK)9(9P^-4l6m~w9K)C4# z4w*YnT5J7Qh|i8c4mULc>TeHz-f9;=sP9Ufn&CLhN#{Qk%r@aT9^s~wfJQVQ+upBw znLxPd6yH>RZMmmid}Z=Wxan6|>Rhz?cte=@USO3rwZPKh3$8W0=RAbtc!ZlQfQF7O z{wX?GOCa3T%Ash-Ri$NdFNyzgehD{SfTahUR605}QCPYoZMp=gD^{DKXxdC5+;jzy z^@`W`%t|)`;igVN&2h5%4J{D_!cEX}2T{H7kaMffPXQM`~ zGiWkc#zm!8=TK_#ahcJiV@+vFNBXZv*F^67KJK%B?3qAw@-Djn@gO}wU#AD@Ax=9H t+6nZ4);>D^&YdS~8KGsIHn&w2=hj79pQufYbEYO&PZ^D@GTM|D@*f@76e|D# literal 0 HcmV?d00001 diff --git a/best_genome.pkl b/best_genome.pkl new file mode 100644 index 0000000000000000000000000000000000000000..83a77cc98ec43d8a55c0fdf7855db6491bf72d0c GIT binary patch literal 2923 zcmZwJdrVVj7{_s>Vo}B=OdLA531JE9ybvIoZP+e4xfC$jWNuEKtsHvFrA698P!L(1 zGCYz2$B>YKBsd6$Y-(g1D0Cso19Uh_L9x`y&eL-w@3Uk7g7V?_exYgl z9$k2Pznhce!MJHDrD46A(&|zumT})kB`A%_hH(D-Vs>>do6QC?k|ZjVRea019nfjD z)B%G=r`5v`qUN}V!s$56EP1&@=~>1{nAtk7xj~A3IY+zDuFRqgN~5L?Y`AR3){x)l zRl!Iptx}at#j|-v6+gV-I7_?n3Eb~^(H`NZk*?$Qv3Yvha~?S(^g1Bx!Q~;Be)1rY z1fgYs%!e$Gjko7`^g{SMh~t)XNH&lhb-l!E3xROcN?3YyCVfeEes8g671Ex zy2&Nn#(F`d8faE}-6ZHwq)$y9k7v0s(agpJlqH4igABt%auh0XLe9 zTh9{+H?0SBuhwrSDW!-<&&5qaaGZHXfJRnaKp@-{0!wul{o6On)({9cZGxrgnCw*^ zsXYY3P4C0fcwIMJ@O*|xc5zb}Ed9|s-I2M!kwCaf4oh=6@0IV}+d?4Rv=x@>?Qt1F zZj}VWP1|8<)HR^(QQQEJX2ngBu=HKgL9aJ!djueRm0jn|8p`!$3*f zU{nKvaMLHSG`2jcY|+F&JhF+K;sAAKT*>K33?UG1`W#Sho#BYjNC|;()0faR+*X)X zfAj|e;ii4CG+;}~^>?|&qi5nKC7|J+5BmPJM-T`%#lvxS6~p-GMhJwP)PQ=Y)&#p$ z{6-+$ln7{`{k3D~%CGQfM%ifu)D_(uSsMHwlEB@?dGGXEMgCKbki?5ZPAP{aU0d%kQ zbV=C`3xRM`DThoqcg}1Ny}Fi1Q{tvFKwXy5JL9W_|C0mZrV2Q&H=x?dXS47|a3I`N z320=gWaRO-D&jcYR0ZhSF3$;LWD<|8;-(sCvZjXU{2l}m2shO_nnw1Nd*}NR2six* zsBBLA;{J!i2aw}9+;kDp@JA=@zdH-B-K4mw4$#nNjbrk_b^LJ-gqwZ=WKE5$iYe_R z5N^5*sDJ!&MQhP0fpAk3Aal_5qOdnIcw`YbH9L-rO{i#csU{F^Y5~+?oAds7sqo3^ zSi()M95Rh9Hf_$?)=VtnrZzwwmBXg<9m2~pA#S<_$Q}i|4L>+?`P2L@#DMg&SZj z_w4`fehMdMp-Q9Fb2mS`BvY3-W@B8_D19n-rKGs$EW*WfiITf!rYqs9$ud&CMxCP5 z#IuZ>Qmsy-)WY);y)lI~suT<8-j~-z?rU?nX>e)uR^MV1Jup8=57H0lA$pk8&V+Ub pU2`h5df7_(tF?sC63|)R;U-^q)ho4;ktS#iT1u~H6>_6${eRDJ8D; None: # Menu().run() - train() + train(100) if __name__ == "__main__": diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 4102b72..38f89c2 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,3 +1,4 @@ -from .train import train +from .io import read_genome +from .training import train -__all__ = ["train"] +__all__ = ["train", "read_genome"] diff --git a/src/ai/config.py b/src/ai/config.py new file mode 100644 index 0000000..c111479 --- /dev/null +++ b/src/ai/config.py @@ -0,0 +1,13 @@ +import neat +from path import BASE_PATH + + +def get_config() -> neat.Config: + config_path = BASE_PATH / "config.txt" + return neat.Config( + neat.DefaultGenome, + neat.DefaultReproduction, + neat.DefaultSpeciesSet, + neat.DefaultStagnation, + config_path, + ) diff --git a/src/ai/evaluation.py b/src/ai/evaluation.py new file mode 100644 index 0000000..1efec3a --- /dev/null +++ b/src/ai/evaluation.py @@ -0,0 +1,46 @@ +import neat +from loguru import logger +from py2048 import Menu + + +def eval_genomes(genomes, config: neat.Config): + for genome_id, genome in genomes: + genome.fitness = 0 + app = Menu() + net = neat.nn.FeedForwardNetwork.create(genome, config) + + app.play() + app._game_active = False + + while True: + output = net.activate( + ( + *app.game.board.matrix(), + app.game.board.score, + ) + ) + + decision = output.index(max(output)) + + decisions = { + 0: app.game.move_up, + 1: app.game.move_down, + 2: app.game.move_left, + 3: app.game.move_right, + } + + decisions[decision]() + + app._hande_events() + app.game.draw(app._surface) + max_val = app.game.board.max_val() + + if app.game.board._is_full() or max_val >= 2048: + calculate_fitness(genome, max_val) + logger.info(f"{max_val=}") + app.game.restart() + break + + +def calculate_fitness(genome: neat.DefaultGenome, score: int): + genome.fitness += score diff --git a/src/ai/io.py b/src/ai/io.py new file mode 100644 index 0000000..e858dde --- /dev/null +++ b/src/ai/io.py @@ -0,0 +1,15 @@ +import pickle +from pathlib import Path + +import neat +from path import BASE_PATH + + +def read_genome(filename: Path) -> neat.DefaultGenome: + with open(filename, "rb") as f: + return pickle.load(f) + + +def save_genome(genome, filename: Path) -> None: + with open(filename, "wb") as f: + pickle.dump(genome, f) diff --git a/src/ai/train.py b/src/ai/train.py deleted file mode 100644 index 861e7cb..0000000 --- a/src/ai/train.py +++ /dev/null @@ -1,73 +0,0 @@ -import neat -from loguru import logger -from path import BASE_PATH -from py2048 import Menu - - -def _get_config() -> neat.Config: - config_path = BASE_PATH / "config.txt" - return neat.Config( - neat.DefaultGenome, - neat.DefaultReproduction, - neat.DefaultSpeciesSet, - neat.DefaultStagnation, - config_path, - ) - - -def train() -> None: - config = _get_config() - # p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0") - p = neat.Population(config) - p.add_reporter(neat.StdOutReporter(True)) - stats = neat.StatisticsReporter() - p.add_reporter(stats) - p.add_reporter(neat.Checkpointer(1)) - - winner = p.run(eval_genomes, 50) - - logger.info(f"\nBest genome:\n{winner}") - - -def eval_genomes(genomes, config: neat.Config): - for genome_id, genome in genomes: - genome.fitness = 4.0 - app = Menu() - net = neat.nn.FeedForwardNetwork.create(genome, config) - - app.play() - app._game_active = False - - while True: - output = net.activate( - ( - *app.game.board.matrix(), - app.game.board.score, - ) - ) - - decision = output.index(max(output)) - - decisions = { - 0: app.game.move_up, - 1: app.game.move_down, - 2: app.game.move_left, - 3: app.game.move_right, - } - - decisions[decision]() - - app._hande_events() - app.game.draw(app._surface) - - if app.game.board._is_full() or app.game.board.score > 10_000: - calculate_fitness(genome, app.game.board.score) - logger.info( - f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}" - ) - app.game.restart() - break - - -def calculate_fitness(genome, score: int): - genome.fitness += score diff --git a/src/ai/training.py b/src/ai/training.py new file mode 100644 index 0000000..47ee69b --- /dev/null +++ b/src/ai/training.py @@ -0,0 +1,22 @@ +import neat +from loguru import logger +from path import BASE_PATH + +from .config import get_config +from .evaluation import eval_genomes +from .io import save_genome + + +def train(generations: int) -> None: + """Train the AI for a given number of generations.""" + config = get_config() + population = neat.Population(config) + population.add_reporter(neat.StdOutReporter(True)) + stats = neat.StatisticsReporter() + population.add_reporter(stats) + population.add_reporter(neat.Checkpointer(1)) + + winner = population.run(eval_genomes, generations) + + logger.info(winner) + save_genome(winner, BASE_PATH / "best_genome") diff --git a/src/py2048/objects/board.py b/src/py2048/objects/board.py index 5fa6a6c..e136dd3 100644 --- a/src/py2048/objects/board.py +++ b/src/py2048/objects/board.py @@ -116,6 +116,11 @@ class Board(pygame.sprite.Group): self.empty() self._initiate_game() + def max_val(self) -> int: + """Return the maximum value of the tiles.""" + tile: Tile + return int(max(tile.value for tile in self.sprites())) + def get_tile(self, position: Position) -> Optional[Tile]: """Return the tile at the specified position.""" tile: Tile