improve: provide api to save/load/get status

This commit is contained in:
h z
2025-01-28 13:55:27 +00:00
parent 4f440e7e20
commit 34e9b95145
3 changed files with 99 additions and 24 deletions

View File

@@ -4,6 +4,7 @@ from tensorflow import keras
from Board import Board from Board import Board
import random import random
import os
class NeuralSolver: class NeuralSolver:
def __init__(self, width, height, quotient_x=False, quotient_y=False): def __init__(self, width, height, quotient_x=False, quotient_y=False):
@@ -98,3 +99,20 @@ class NeuralSolver:
for (cx, cy) in board.lives: for (cx, cy) in board.lives:
y[i, cx, cy, 0] = 1.0 y[i, cx, cy, 0] = 1.0
return x, y return x, y
def spec(self):
    """Return a compact configuration tag, e.g. ``"8x8xFalsexTrue"``.

    Used to name the weight files so models trained with different
    board dimensions / quotient flags never collide on disk.
    """
    fields = (self.Width, self.Height, self.QuotientX, self.QuotientY)
    return "x".join(str(field) for field in fields)
def save_model(self):
    """Write the weights of both networks to spec-tagged ``.h5`` files."""
    for prefix, model in (("Forward", self.ForwardModel), ("Reverse", self.ReverseModel)):
        model.save_weights(f"{prefix}Model_{self.spec()}.h5")
def load_model(self):
    """Load previously saved weights for both networks, if present.

    Returns:
        bool: ``True`` when both spec-tagged weight files exist and were
        loaded into ``ForwardModel``/``ReverseModel``; ``False`` when
        either file is missing (models are left untouched).
    """
    # Build each path exactly once so the existence check and the load
    # can never drift apart (the original rebuilt every f-string twice).
    forward_path = f"ForwardModel_{self.spec()}.h5"
    reverse_path = f"ReverseModel_{self.spec()}.h5"
    if not (os.path.exists(forward_path) and os.path.exists(reverse_path)):
        return False
    self.ForwardModel.load_weights(forward_path)
    self.ReverseModel.load_weights(reverse_path)
    return True

96
app.py
View File

@@ -1,5 +1,4 @@
import itertools import itertools
from fastapi import FastAPI, BackgroundTasks from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel from pydantic import BaseModel
import numpy as np import numpy as np
@@ -10,47 +9,74 @@ app = FastAPI()
solver: NeuralSolver | None = None solver: NeuralSolver | None = None
status: str = "none" status: str = "none"
def task(): def task(request):
global solver global solver
global status global status
status = "training forward" status = "training forward"
solver.train_forward(1000) solver.train_forward(request.ForwardDatasetSize, request.ForwardBatchSize, request.ForwardEpochs)
status = "training backward" status = "training backward"
solver.train_backward(1000) solver.train_backward(request.BackwardDatasetSize, request.BackwardBatchSize, request.BackwardEpochs)
status = "trained" status = "ready"
class InitRequest(BaseModel): class InitRequest(BaseModel):
width: int Width: int
height: int Height: int
quotientX: bool = False QuotientX: bool = False
quotientY: bool = False QuotientY: bool = False
@app.post("/initialize") @app.post("/initialize")
def initialize(request: InitRequest, background_tasks: BackgroundTasks): def initialize(request: InitRequest):
global solver global solver
global status global status
if status != "none": if status != "none":
return {"status": "instance already existed"} return {"Status": "instance already existed"}
solver = NeuralSolver(request.width, request.height, request.quotientX, request.quotientY) solver = NeuralSolver(request.Width, request.Height, request.QuotientX, request.QuotientY)
background_tasks.add_task(task) status = "initialized"
return {"status": "initializing"} return {"Status": status}
class TrainRequest(BaseModel):
    """Hyperparameters for a training run; request body of /train and /try_load."""
    # Forwarded positionally to NeuralSolver.train_forward(size, batch, epochs).
    ForwardDatasetSize: int = 1000
    ForwardBatchSize: int = 8
    ForwardEpochs: int = 10
    # Forwarded positionally to NeuralSolver.train_backward(size, batch, epochs).
    BackwardDatasetSize: int = 1000
    BackwardBatchSize: int = 8
    BackwardEpochs: int = 10
@app.post("/train")
def train(request: TrainRequest, background_tasks: BackgroundTasks):
global solver
global status
if status != "initialized":
return {"Status": "model not initialized"}
background_tasks.add_task(lambda : task(request))
@app.post("/try_load")
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
global solver
global status
if solver.load_model():
status = "ready"
return {"Status": "ok"}
background_tasks.add_task(lambda : task(request))
class BoardRequest(BaseModel): class BoardRequest(BaseModel):
lives: List[Tuple[int, int]] Lives: List[Tuple[int, int]]
direction: str Direction: str
@app.post("/predict") @app.post("/predict")
def predict(request: BoardRequest): def predict(request: BoardRequest):
global solver global solver
global status global status
if status != "trained": if status != "ready":
return {"status": "not trained yet"} return {"Status": "model not ready yet"}
inputs = np.zeros((1, solver.Width, solver.Height, 1)) inputs = np.zeros((1, solver.Width, solver.Height, 1))
for (x, y) in request.lives: for (x, y) in request.Lives:
inputs[0, x, y, 0] = 1.0 inputs[0, x, y, 0] = 1.0
res = None res = None
if request.direction == "forward": if request.Direction == "forward":
res = solver.predict_forward(inputs) res = solver.predict_forward(inputs)
else: else:
res = solver.predict_reverse(inputs) res = solver.predict_reverse(inputs)
@@ -58,7 +84,7 @@ def predict(request: BoardRequest):
for (x, y) in itertools.product(range(solver.Width), range(solver.Height)): for (x, y) in itertools.product(range(solver.Width), range(solver.Height)):
if res[0, x, y, 0]: if res[0, x, y, 0]:
lives.add((x, y)) lives.add((x, y))
return {"prediction": lives} return {"Lives": lives}
@app.post("/finish") @app.post("/finish")
def finish(): def finish():
@@ -68,4 +94,32 @@ def finish():
status = "none" status = "none"
@app.post("/save")
def save():
global solver
global status
if status != "ready":
return {"Status": "model not ready yet"}
solver.save_model()
return {"Status": "ok"}
@app.post("/load")
def load():
global solver
global status
if status != "initialized":
return {"Status": "model not initialized yet"}
if solver.load_model():
status = "ready"
return {"Status": "ok"}
return {"Status": "model loading failed"}
@app.get("/status")
def get_status():
global status
return {"Status": status}

View File

@@ -1,2 +1,5 @@
numpy~=2.0.2 numpy>=2.0.2
tensorflow~=2.18.0 fastapi>=0.115.7
pydantic>=2.10.6
setuptools>=75.8.0
tensorflow>=2.18.0