improve: provide api to save/load/get status

This commit is contained in:
h z
2025-01-28 13:55:27 +00:00
parent 4f440e7e20
commit 34e9b95145
3 changed files with 99 additions and 24 deletions

View File

@@ -4,6 +4,7 @@ from tensorflow import keras
from Board import Board
import random
import os
class NeuralSolver:
def __init__(self, width, height, quotient_x=False, quotient_y=False):
@@ -97,4 +98,21 @@ class NeuralSolver:
board.evaluate()
for (cx, cy) in board.lives:
y[i, cx, cy, 0] = 1.0
return x, y
return x, y
def spec(self):
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}"
def save_model(self):
self.ForwardModel.save_weights(f"ForwardModel_{self.spec()}.h5")
self.ReverseModel.save_weights(f"ReverseModel_{self.spec()}.h5")
def load_model(self):
if os.path.exists(f"ForwardModel_{self.spec()}.h5") and os.path.exists(f"ReverseModel_{self.spec()}.h5"):
self.ForwardModel.load_weights(f"ForwardModel_{self.spec()}.h5")
self.ReverseModel.load_weights(f"ReverseModel_{self.spec()}.h5")
return True
return False

96
app.py
View File

@@ -1,5 +1,4 @@
import itertools
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
import numpy as np
@@ -10,47 +9,74 @@ app = FastAPI()
solver: NeuralSolver | None = None
status: str = "none"
def task():
def task(request):
global solver
global status
status = "training forward"
solver.train_forward(1000)
solver.train_forward(request.ForwardDatasetSize, request.ForwardBatchSize, request.ForwardEpochs)
status = "training backward"
solver.train_backward(1000)
status = "trained"
solver.train_backward(request.BackwardDatasetSize, request.BackwardBatchSize, request.BackwardEpochs)
status = "ready"
class InitRequest(BaseModel):
width: int
height: int
quotientX: bool = False
quotientY: bool = False
Width: int
Height: int
QuotientX: bool = False
QuotientY: bool = False
@app.post("/initialize")
def initialize(request: InitRequest, background_tasks: BackgroundTasks):
def initialize(request: InitRequest):
global solver
global status
if status != "none":
return {"status": "instance already existed"}
solver = NeuralSolver(request.width, request.height, request.quotientX, request.quotientY)
background_tasks.add_task(task)
return {"status": "initializing"}
return {"Status": "instance already existed"}
solver = NeuralSolver(request.Width, request.Height, request.QuotientX, request.QuotientY)
status = "initialized"
return {"Status": status}
class TrainRequest(BaseModel):
ForwardDatasetSize: int = 1000
ForwardBatchSize: int = 8
ForwardEpochs: int = 10
BackwardDatasetSize: int = 1000
BackwardBatchSize: int = 8
BackwardEpochs: int = 10
@app.post("/train")
def train(request: TrainRequest, background_tasks: BackgroundTasks):
global solver
global status
if status != "initialized":
return {"Status": "model not initialized"}
background_tasks.add_task(lambda : task(request))
@app.post("/try_load")
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
global solver
global status
if solver.load_model():
status = "ready"
return {"Status": "ok"}
background_tasks.add_task(lambda : task(request))
class BoardRequest(BaseModel):
lives: List[Tuple[int, int]]
direction: str
Lives: List[Tuple[int, int]]
Direction: str
@app.post("/predict")
def predict(request: BoardRequest):
global solver
global status
if status != "trained":
return {"status": "not trained yet"}
if status != "ready":
return {"Status": "model not ready yet"}
inputs = np.zeros((1, solver.Width, solver.Height, 1))
for (x, y) in request.lives:
for (x, y) in request.Lives:
inputs[0, x, y, 0] = 1.0
res = None
if request.direction == "forward":
if request.Direction == "forward":
res = solver.predict_forward(inputs)
else:
res = solver.predict_reverse(inputs)
@@ -58,7 +84,7 @@ def predict(request: BoardRequest):
for (x, y) in itertools.product(range(solver.Width), range(solver.Height)):
if res[0, x, y, 0]:
lives.add((x, y))
return {"prediction": lives}
return {"Lives": lives}
@app.post("/finish")
def finish():
@@ -68,4 +94,32 @@ def finish():
status = "none"
@app.post("/save")
def save():
global solver
global status
if status != "ready":
return {"Status": "model not ready yet"}
solver.save_model()
return {"Status": "ok"}
@app.post("/load")
def load():
global solver
global status
if status != "initialized":
return {"Status": "model not initialized yet"}
if solver.load_model():
status = "ready"
return {"Status": "ok"}
return {"Status": "model loading failed"}
@app.get("/status")
def get_status():
global status
return {"Status": status}

View File

@@ -1,2 +1,5 @@
numpy~=2.0.2
tensorflow~=2.18.0
numpy>=2.0.2
fastapi>=0.115.7
pydantic>=2.10.6
setuptools>=75.8.0
tensorflow>=2.18.0