add: data generator
This commit is contained in:
@@ -15,6 +15,10 @@ class Board:
|
|||||||
self.lives.add((x, y))
|
self.lives.add((x, y))
|
||||||
|
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
|
new_lives = self.try_evaluate()
|
||||||
|
self.lives = new_lives
|
||||||
|
|
||||||
|
def try_evaluate(self) -> set[(int, int)]:
|
||||||
new_lives = set()
|
new_lives = set()
|
||||||
for (x, y) in itertools.product(range(self.width), range(self.height)):
|
for (x, y) in itertools.product(range(self.width), range(self.height)):
|
||||||
neighbor_count = 0
|
neighbor_count = 0
|
||||||
@@ -33,5 +37,18 @@ class Board:
|
|||||||
new_lives.add((x, y))
|
new_lives.add((x, y))
|
||||||
if (x, y) not in self.lives and neighbor_count == 3:
|
if (x, y) not in self.lives and neighbor_count == 3:
|
||||||
new_lives.add((x, y))
|
new_lives.add((x, y))
|
||||||
self.lives = new_lives
|
return new_lives
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
res = ""
|
||||||
|
for y in range(self.height):
|
||||||
|
for x in range(self.width):
|
||||||
|
if (x, y) in self.lives:
|
||||||
|
res += "o"
|
||||||
|
else:
|
||||||
|
res += " "
|
||||||
|
res += "\n"
|
||||||
|
return res
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
56
DataGenerator/CircleGenerator.py
Normal file
56
DataGenerator/CircleGenerator.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
from Board import Board
|
||||||
|
from DataGenerator import DataGenerator
|
||||||
|
|
||||||
|
def _circle_generator(cx, cy, px, py):
|
||||||
|
yield cx + px, cy + py
|
||||||
|
yield cx - px, cy + py
|
||||||
|
yield cx + px, cy - py
|
||||||
|
yield cx - px, cy - py
|
||||||
|
yield cx + py, cy + px
|
||||||
|
yield cx - py, cy + px
|
||||||
|
yield cx + py, cy - px
|
||||||
|
yield cx - py, cy - px
|
||||||
|
|
||||||
|
def circle_generator(cx, cy, radius):
|
||||||
|
x = 0
|
||||||
|
y = radius
|
||||||
|
d = 1 - radius
|
||||||
|
for tx, ty in _circle_generator(cx, cy, x, y):
|
||||||
|
yield tx, ty
|
||||||
|
while x < y:
|
||||||
|
x += 1
|
||||||
|
if d < 0:
|
||||||
|
d += 2 * x + 1
|
||||||
|
else:
|
||||||
|
y -= 1
|
||||||
|
d += 2 * (x - y) + 1
|
||||||
|
for tx, ty in _circle_generator(cx, cy, x, y):
|
||||||
|
yield tx, ty
|
||||||
|
|
||||||
|
class CircleGenerator(DataGenerator):
|
||||||
|
def generate(self, amount: int) -> list[Board]:
|
||||||
|
res = []
|
||||||
|
generated = set()
|
||||||
|
for i in range(amount):
|
||||||
|
success = False
|
||||||
|
while not success:
|
||||||
|
success = True
|
||||||
|
b = Board(self.width, self.height, self.qx, self.qy)
|
||||||
|
cx, cy, radius = self.random_config()
|
||||||
|
while (cx, cy, radius) in generated:
|
||||||
|
cx, cy, radius = self.random_config()
|
||||||
|
for tx, ty in circle_generator(cx, cy, radius):
|
||||||
|
b.lives.add((tx, ty))
|
||||||
|
if len(b.try_evaluate()) == 0:
|
||||||
|
success = False
|
||||||
|
continue
|
||||||
|
res.append(b)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def random_config(self):
|
||||||
|
cx = random.randint(min(5, self.width), self.width - 1)
|
||||||
|
cy = random.randint(min(5, self.width), self.height - 1)
|
||||||
|
radius = random.randint(2, min(self.width, self.height))
|
||||||
|
return cx, cy, radius
|
||||||
17
DataGenerator/__init__.py
Normal file
17
DataGenerator/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from Board import Board
|
||||||
|
|
||||||
|
|
||||||
|
class DataGenerator:
|
||||||
|
def __init__(self, width, height, qx, qy):
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.qx = qx
|
||||||
|
self.qy = qy
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self, amount: int) -> list[Board]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
|
from tensorflow.keras import regularizers
|
||||||
from Board import Board
|
from Board import Board
|
||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
@@ -29,8 +29,10 @@ class NeuralSolver:
|
|||||||
|
|
||||||
def _build_reverse_model(self):
|
def _build_reverse_model(self):
|
||||||
inputs = keras.Input(shape=(self.Width, self.Height, 1), name="FinalState")
|
inputs = keras.Input(shape=(self.Width, self.Height, 1), name="FinalState")
|
||||||
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
|
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu",kernel_regularizer=regularizers.l2(1e-4))(inputs)
|
||||||
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(hidden)
|
hidden = keras.layers.BatchNormalization()(hidden)
|
||||||
|
hidden = keras.layers.Conv2D(32, 5, padding="same", activation="relu",kernel_regularizer=regularizers.l2(1e-4))(hidden)
|
||||||
|
hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu",kernel_regularizer=regularizers.l2(1e-4))(hidden)
|
||||||
outputs = keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(hidden)
|
outputs = keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(hidden)
|
||||||
self.ReverseModel = keras.Model(inputs, outputs, name="ReverseModel")
|
self.ReverseModel = keras.Model(inputs, outputs, name="ReverseModel")
|
||||||
|
|
||||||
@@ -102,7 +104,7 @@ class NeuralSolver:
|
|||||||
|
|
||||||
|
|
||||||
def spec(self):
|
def spec(self):
|
||||||
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}"
|
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}.weights"
|
||||||
|
|
||||||
|
|
||||||
def save_model(self):
|
def save_model(self):
|
||||||
|
|||||||
6
app.py
6
app.py
@@ -56,10 +56,13 @@ def train(request: TrainRequest, background_tasks: BackgroundTasks):
|
|||||||
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
|
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
|
||||||
global solver
|
global solver
|
||||||
global status
|
global status
|
||||||
|
if status != "initialized":
|
||||||
|
return {"Status": "model not initialized"}
|
||||||
if solver.load_model():
|
if solver.load_model():
|
||||||
status = "ready"
|
status = "ready"
|
||||||
return {"Status": "ok"}
|
return {"Status": "ok"}
|
||||||
background_tasks.add_task(lambda : task(request))
|
else:
|
||||||
|
background_tasks.add_task(lambda : task(request))
|
||||||
|
|
||||||
|
|
||||||
class BoardRequest(BaseModel):
|
class BoardRequest(BaseModel):
|
||||||
@@ -115,7 +118,6 @@ def load():
|
|||||||
return {"Status": "model loading failed"}
|
return {"Status": "model loading failed"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/status")
|
@app.get("/status")
|
||||||
def get_status():
|
def get_status():
|
||||||
global status
|
global status
|
||||||
|
|||||||
Reference in New Issue
Block a user