mirror of
https://github.com/kristoferssolo/2048.git
synced 2025-10-21 15:20:35 +00:00
feat(ai): add ai
This commit is contained in:
parent
97a64b44b6
commit
e03465d2d3
79
config.txt
Normal file
79
config.txt
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
[NEAT]
|
||||||
|
fitness_criterion = max
|
||||||
|
fitness_threshold = 3.9
|
||||||
|
pop_size = 150
|
||||||
|
reset_on_extinction = False
|
||||||
|
|
||||||
|
[DefaultGenome]
|
||||||
|
# node activation options
|
||||||
|
activation_default = sigmoid
|
||||||
|
activation_mutate_rate = 0.0
|
||||||
|
activation_options = sigmoid
|
||||||
|
|
||||||
|
# node aggregation options
|
||||||
|
aggregation_default = sum
|
||||||
|
aggregation_mutate_rate = 0.0
|
||||||
|
aggregation_options = sum
|
||||||
|
|
||||||
|
# node bias options
|
||||||
|
bias_init_mean = 0.0
|
||||||
|
bias_init_stdev = 1.0
|
||||||
|
bias_max_value = 30.0
|
||||||
|
bias_min_value = -30.0
|
||||||
|
bias_mutate_power = 0.5
|
||||||
|
bias_mutate_rate = 0.7
|
||||||
|
bias_replace_rate = 0.1
|
||||||
|
|
||||||
|
# genome compatibility options
|
||||||
|
compatibility_disjoint_coefficient = 1.0
|
||||||
|
compatibility_weight_coefficient = 0.5
|
||||||
|
|
||||||
|
# connection add/remove rates
|
||||||
|
conn_add_prob = 0.5
|
||||||
|
conn_delete_prob = 0.5
|
||||||
|
|
||||||
|
# connection enable options
|
||||||
|
enabled_default = True
|
||||||
|
enabled_mutate_rate = 0.01
|
||||||
|
|
||||||
|
feed_forward = True
|
||||||
|
initial_connection = full
|
||||||
|
|
||||||
|
# node add/remove rates
|
||||||
|
node_add_prob = 0.2
|
||||||
|
node_delete_prob = 0.2
|
||||||
|
|
||||||
|
# network parameters
|
||||||
|
num_hidden = 0
|
||||||
|
num_inputs = 17
|
||||||
|
num_outputs = 4
|
||||||
|
|
||||||
|
# node response options
|
||||||
|
response_init_mean = 1.0
|
||||||
|
response_init_stdev = 0.0
|
||||||
|
response_max_value = 30.0
|
||||||
|
response_min_value = -30.0
|
||||||
|
response_mutate_power = 0.0
|
||||||
|
response_mutate_rate = 0.0
|
||||||
|
response_replace_rate = 0.0
|
||||||
|
|
||||||
|
# connection weight options
|
||||||
|
weight_init_mean = 0.0
|
||||||
|
weight_init_stdev = 1.0
|
||||||
|
weight_max_value = 30
|
||||||
|
weight_min_value = -30
|
||||||
|
weight_mutate_power = 0.5
|
||||||
|
weight_mutate_rate = 0.8
|
||||||
|
weight_replace_rate = 0.1
|
||||||
|
|
||||||
|
[DefaultSpeciesSet]
|
||||||
|
compatibility_threshold = 3.0
|
||||||
|
|
||||||
|
[DefaultStagnation]
|
||||||
|
species_fitness_func = max
|
||||||
|
max_stagnation = 20
|
||||||
|
species_elitism = 2
|
||||||
|
|
||||||
|
[DefaultReproduction]
|
||||||
|
elitism = 2
|
||||||
|
survival_threshold = 0.2
|
||||||
4
main.py
4
main.py
@ -1,13 +1,15 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
|
||||||
|
from ai import train
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from py2048 import Menu
|
from py2048 import Menu
|
||||||
|
|
||||||
|
|
||||||
@logger.catch
|
@logger.catch
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
Menu().run()
|
# Menu().run()
|
||||||
|
train()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -6,7 +6,12 @@ authors = [{ name = "Kristofers Solo", email = "dev@kristofers.xyz" }]
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
license = { text = "GPLv3" }
|
license = { text = "GPLv3" }
|
||||||
dependencies = ["pygame-ce==2.3.2", "loguru==0.7.2", "attrs==23.1.0"]
|
dependencies = [
|
||||||
|
"pygame-ce==2.3.2",
|
||||||
|
"loguru==0.7.2",
|
||||||
|
"attrs==23.1.0",
|
||||||
|
"neat-python>=0.92",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
pygame-ce>=2.3.2
|
pygame-ce>=2.3.2
|
||||||
loguru>=0.7.2
|
loguru>=0.7.2
|
||||||
attrs>=23.1.0
|
attrs>=23.1.0
|
||||||
|
neat-python>=0.92
|
||||||
.
|
.
|
||||||
|
|||||||
3
src/ai/__init__.py
Normal file
3
src/ai/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .train import train
|
||||||
|
|
||||||
|
__all__ = ["train"]
|
||||||
73
src/ai/train.py
Normal file
73
src/ai/train.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import neat
|
||||||
|
from loguru import logger
|
||||||
|
from path import BASE_PATH
|
||||||
|
from py2048 import Menu
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config() -> neat.Config:
|
||||||
|
config_path = BASE_PATH / "config.txt"
|
||||||
|
return neat.Config(
|
||||||
|
neat.DefaultGenome,
|
||||||
|
neat.DefaultReproduction,
|
||||||
|
neat.DefaultSpeciesSet,
|
||||||
|
neat.DefaultStagnation,
|
||||||
|
config_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train() -> None:
|
||||||
|
config = _get_config()
|
||||||
|
# p = neat.Checkpointer.restore_checkpoint("neat-checkpoint-0")
|
||||||
|
p = neat.Population(config)
|
||||||
|
p.add_reporter(neat.StdOutReporter(True))
|
||||||
|
stats = neat.StatisticsReporter()
|
||||||
|
p.add_reporter(stats)
|
||||||
|
p.add_reporter(neat.Checkpointer(1))
|
||||||
|
|
||||||
|
winner = p.run(eval_genomes, 50)
|
||||||
|
|
||||||
|
logger.info(f"\nBest genome:\n{winner}")
|
||||||
|
|
||||||
|
|
||||||
|
def eval_genomes(genomes, config: neat.Config):
|
||||||
|
for genome_id, genome in genomes:
|
||||||
|
genome.fitness = 4.0
|
||||||
|
app = Menu()
|
||||||
|
net = neat.nn.FeedForwardNetwork.create(genome, config)
|
||||||
|
|
||||||
|
app.play()
|
||||||
|
app._game_active = False
|
||||||
|
|
||||||
|
while True:
|
||||||
|
output = net.activate(
|
||||||
|
(
|
||||||
|
*app.game.board.matrix(),
|
||||||
|
app.game.board.score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
decision = output.index(max(output))
|
||||||
|
|
||||||
|
decisions = {
|
||||||
|
0: app.game.move_up,
|
||||||
|
1: app.game.move_down,
|
||||||
|
2: app.game.move_left,
|
||||||
|
3: app.game.move_right,
|
||||||
|
}
|
||||||
|
|
||||||
|
decisions[decision]()
|
||||||
|
|
||||||
|
app._hande_events()
|
||||||
|
app.game.draw(app._surface)
|
||||||
|
|
||||||
|
if app.game.board._is_full() or app.game.board.score > 10_000:
|
||||||
|
calculate_fitness(genome, app.game.board.score)
|
||||||
|
logger.info(
|
||||||
|
f"Genome: {genome_id} fitness: {genome.fitness} score: {app.game.board.score}"
|
||||||
|
)
|
||||||
|
app.game.restart()
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_fitness(genome, score: int):
|
||||||
|
genome.fitness += score
|
||||||
@ -1,4 +1,5 @@
|
|||||||
import random
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pygame
|
import pygame
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -114,3 +115,25 @@ class Board(pygame.sprite.Group):
|
|||||||
"""Reset the board."""
|
"""Reset the board."""
|
||||||
self.empty()
|
self.empty()
|
||||||
self._initiate_game()
|
self._initiate_game()
|
||||||
|
|
||||||
|
def get_tile(self, position: Position) -> Optional[Tile]:
|
||||||
|
"""Return the tile at the specified position."""
|
||||||
|
tile: Tile
|
||||||
|
for tile in self.sprites():
|
||||||
|
if tile.pos == position:
|
||||||
|
return tile
|
||||||
|
return None
|
||||||
|
|
||||||
|
def matrix(self) -> list[int]:
|
||||||
|
"""Return a 1d matrix of values of the tiles."""
|
||||||
|
matrix: list[int] = []
|
||||||
|
|
||||||
|
for i in range(1, Config.BOARD.len + 1):
|
||||||
|
for j in range(1, Config.BOARD.len + 1):
|
||||||
|
tile = self.get_tile(Position(j, i))
|
||||||
|
if tile:
|
||||||
|
matrix.append(tile.value)
|
||||||
|
else:
|
||||||
|
matrix.append(0)
|
||||||
|
|
||||||
|
return matrix
|
||||||
|
|||||||
@ -10,9 +10,11 @@ from py2048.utils import ColorScheme, Direction, Position, Size
|
|||||||
from .abc import MovableUIElement, UIElement
|
from .abc import MovableUIElement, UIElement
|
||||||
|
|
||||||
|
|
||||||
def _grid_pos(pos: int) -> int:
|
def _grid_pos(position: Position) -> Position:
|
||||||
"""Return the position in the grid."""
|
"""Return the position in the grid."""
|
||||||
return pos // Config.TILE.size + 1
|
x = (position.x - Config.BOARD.pos.x) // Config.TILE.size + 1
|
||||||
|
y = (position.y - Config.BOARD.pos.y) // Config.TILE.size + 1
|
||||||
|
return Position(x, y)
|
||||||
|
|
||||||
|
|
||||||
class Tile(MovableUIElement, pygame.sprite.Sprite):
|
class Tile(MovableUIElement, pygame.sprite.Sprite):
|
||||||
@ -191,9 +193,9 @@ class Tile(MovableUIElement, pygame.sprite.Sprite):
|
|||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
"""Return a hash of the tile."""
|
"""Return a hash of the tile."""
|
||||||
return hash((self.rect.x, self.rect.y, self.value))
|
return hash((self.rect.x, self.rect.y))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pos(self) -> Position:
|
def pos(self) -> Position:
|
||||||
"""Return the position of the tile."""
|
"""Return the position of the tile."""
|
||||||
return Position(_grid_pos(self.rect.x), _grid_pos(self.rect.y))
|
return _grid_pos(Position(*self.rect.topleft))
|
||||||
|
|||||||
@ -49,10 +49,6 @@ class Game:
|
|||||||
logger.info("Game over!")
|
logger.info("Game over!")
|
||||||
self.restart()
|
self.restart()
|
||||||
|
|
||||||
def restart(self) -> None:
|
|
||||||
self.board.reset()
|
|
||||||
self.update_score(0)
|
|
||||||
|
|
||||||
def move_up(self) -> None:
|
def move_up(self) -> None:
|
||||||
self.move(Direction.UP)
|
self.move(Direction.UP)
|
||||||
|
|
||||||
@ -65,6 +61,11 @@ class Game:
|
|||||||
def move_right(self) -> None:
|
def move_right(self) -> None:
|
||||||
self.move(Direction.RIGHT)
|
self.move(Direction.RIGHT)
|
||||||
|
|
||||||
|
def restart(self) -> None:
|
||||||
|
self.board.reset()
|
||||||
|
self.board.score = 0
|
||||||
|
self.update_score(0)
|
||||||
|
|
||||||
def _create_labels(self) -> pygame.sprite.Group:
|
def _create_labels(self) -> pygame.sprite.Group:
|
||||||
size = Size(60, 40)
|
size = Size(60, 40)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import neat
|
||||||
import pygame
|
import pygame
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@ -61,6 +62,7 @@ class Menu:
|
|||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Run the game loop."""
|
"""Run the game loop."""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
self._hande_events()
|
self._hande_events()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user