add: data generator

This commit is contained in:
h z
2025-01-29 21:24:10 +00:00
parent 34e9b95145
commit 05d375d521
5 changed files with 101 additions and 7 deletions

View File

@@ -15,6 +15,10 @@ class Board:
self.lives.add((x, y))
def evaluate(self):
new_lives = self.try_evaluate()
self.lives = new_lives
def try_evaluate(self) -> set[(int, int)]:
new_lives = set()
for (x, y) in itertools.product(range(self.width), range(self.height)):
neighbor_count = 0
@@ -33,5 +37,18 @@ class Board:
new_lives.add((x, y))
if (x, y) not in self.lives and neighbor_count == 3:
new_lives.add((x, y))
self.lives = new_lives
return new_lives
def __str__(self):
res = ""
for y in range(self.height):
for x in range(self.width):
if (x, y) in self.lives:
res += "o"
else:
res += " "
res += "\n"
return res
def __repr__(self):
return self.__str__()

View File

@@ -0,0 +1,56 @@
import random
from Board import Board
from DataGenerator import DataGenerator
def _circle_generator(cx, cy, px, py):
yield cx + px, cy + py
yield cx - px, cy + py
yield cx + px, cy - py
yield cx - px, cy - py
yield cx + py, cy + px
yield cx - py, cy + px
yield cx + py, cy - px
yield cx - py, cy - px
def circle_generator(cx, cy, radius):
x = 0
y = radius
d = 1 - radius
for tx, ty in _circle_generator(cx, cy, x, y):
yield tx, ty
while x < y:
x += 1
if d < 0:
d += 2 * x + 1
else:
y -= 1
d += 2 * (x - y) + 1
for tx, ty in _circle_generator(cx, cy, x, y):
yield tx, ty
class CircleGenerator(DataGenerator):
def generate(self, amount: int) -> list[Board]:
res = []
generated = set()
for i in range(amount):
success = False
while not success:
success = True
b = Board(self.width, self.height, self.qx, self.qy)
cx, cy, radius = self.random_config()
while (cx, cy, radius) in generated:
cx, cy, radius = self.random_config()
for tx, ty in circle_generator(cx, cy, radius):
b.lives.add((tx, ty))
if len(b.try_evaluate()) == 0:
success = False
continue
res.append(b)
return res
def random_config(self):
cx = random.randint(min(5, self.width), self.width - 1)
cy = random.randint(min(5, self.width), self.height - 1)
radius = random.randint(2, min(self.width, self.height))
return cx, cy, radius

17
DataGenerator/__init__.py Normal file
View File

@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from Board import Board
class DataGenerator:
def __init__(self, width, height, qx, qy):
self.width = width
self.height = height
self.qx = qx
self.qy = qy
@abstractmethod
def generate(self, amount: int) -> list[Board]:
pass

View File

@@ -1,7 +1,7 @@
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
from Board import Board
import random
import os
@@ -29,8 +29,10 @@ class NeuralSolver:
def _build_reverse_model(self):
inputs = keras.Input(shape=(self.Width, self.Height, 1), name="FinalState")
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(hidden)
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu",kernel_regularizer=regularizers.l2(1e-4))(inputs)
hidden = keras.layers.BatchNormalization()(hidden)
hidden = keras.layers.Conv2D(32, 5, padding="same", activation="relu",kernel_regularizer=regularizers.l2(1e-4))(hidden)
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu",kernel_regularizer=regularizers.l2(1e-4))(hidden)
outputs = keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(hidden)
self.ReverseModel = keras.Model(inputs, outputs, name="ReverseModel")
@@ -102,7 +104,7 @@ class NeuralSolver:
def spec(self):
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}"
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}.weights"
def save_model(self):

6
app.py
View File

@@ -56,10 +56,13 @@ def train(request: TrainRequest, background_tasks: BackgroundTasks):
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
global solver
global status
if status != "initialized":
return {"Status": "model not initialized"}
if solver.load_model():
status = "ready"
return {"Status": "ok"}
background_tasks.add_task(lambda : task(request))
else:
background_tasks.add_task(lambda : task(request))
class BoardRequest(BaseModel):
@@ -115,7 +118,6 @@ def load():
return {"Status": "model loading failed"}
@app.get("/status")
def get_status():
global status