improve: provide api to save/load/get status
This commit is contained in:
@@ -4,6 +4,7 @@ from tensorflow import keras
|
||||
|
||||
from Board import Board
|
||||
import random
|
||||
import os
|
||||
|
||||
class NeuralSolver:
|
||||
def __init__(self, width, height, quotient_x=False, quotient_y=False):
|
||||
@@ -97,4 +98,21 @@ class NeuralSolver:
|
||||
board.evaluate()
|
||||
for (cx, cy) in board.lives:
|
||||
y[i, cx, cy, 0] = 1.0
|
||||
return x, y
|
||||
return x, y
|
||||
|
||||
|
||||
def spec(self):
    """Return a filename-friendly key identifying this solver's configuration."""
    return "x".join(str(part) for part in (self.Width, self.Height, self.QuotientX, self.QuotientY))
|
||||
|
||||
|
||||
def save_model(self):
    """Persist both models' weights to HDF5 files named by the board spec."""
    tag = self.spec()
    self.ForwardModel.save_weights(f"ForwardModel_{tag}.h5")
    self.ReverseModel.save_weights(f"ReverseModel_{tag}.h5")
|
||||
|
||||
|
||||
def load_model(self):
    """Load previously saved weights for both models.

    Returns:
        True when both weight files exist and were loaded, False otherwise.
    """
    tag = self.spec()
    forward_path = f"ForwardModel_{tag}.h5"
    reverse_path = f"ReverseModel_{tag}.h5"
    # Both files must be present; a partial pair would leave the solver
    # in an inconsistent half-loaded state.
    if not (os.path.exists(forward_path) and os.path.exists(reverse_path)):
        return False
    self.ForwardModel.load_weights(forward_path)
    self.ReverseModel.load_weights(reverse_path)
    return True
|
||||
96
app.py
96
app.py
@@ -1,5 +1,4 @@
|
||||
import itertools
|
||||
|
||||
from fastapi import FastAPI, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
@@ -10,47 +9,74 @@ app = FastAPI()
|
||||
solver: NeuralSolver | None = None
|
||||
status: str = "none"
|
||||
|
||||
def task(req):
    """Background training job: train the forward model, then the reverse one.

    Updates the module-level ``status`` string as it progresses so that
    clients polling the status endpoint can see where training currently is;
    ends with status "ready".
    """
    global solver
    global status
    status = "training forward"
    solver.train_forward(req.ForwardDatasetSize, req.ForwardBatchSize, req.ForwardEpochs)
    status = "training backward"
    solver.train_backward(req.BackwardDatasetSize, req.BackwardBatchSize, req.BackwardEpochs)
    status = "ready"
|
||||
|
||||
class InitRequest(BaseModel):
    """Request body for /initialize: board dimensions plus quotient flags."""

    # Board dimensions in cells.
    Width: int
    Height: int
    # Quotient (wrap-around) flags per axis — presumably toroidal topology;
    # TODO(review): confirm against NeuralSolver.__init__.
    QuotientX: bool = False
    QuotientY: bool = False
|
||||
|
||||
@app.post("/initialize")
|
||||
def initialize(request: InitRequest, background_tasks: BackgroundTasks):
|
||||
def initialize(request: InitRequest):
|
||||
global solver
|
||||
global status
|
||||
if status != "none":
|
||||
return {"status": "instance already existed"}
|
||||
solver = NeuralSolver(request.width, request.height, request.quotientX, request.quotientY)
|
||||
background_tasks.add_task(task)
|
||||
return {"status": "initializing"}
|
||||
return {"Status": "instance already existed"}
|
||||
solver = NeuralSolver(request.Width, request.Height, request.QuotientX, request.QuotientY)
|
||||
status = "initialized"
|
||||
return {"Status": status}
|
||||
|
||||
|
||||
class TrainRequest(BaseModel):
    """Request body for /train: dataset/batch/epoch settings for each model."""

    # Forward-model training hyperparameters.
    ForwardDatasetSize: int = 1000
    ForwardBatchSize: int = 8
    ForwardEpochs: int = 10
    # Backward (reverse) model training hyperparameters.
    BackwardDatasetSize: int = 1000
    BackwardBatchSize: int = 8
    BackwardEpochs: int = 10
|
||||
|
||||
|
||||
@app.post("/train")
|
||||
def train(request: TrainRequest, background_tasks: BackgroundTasks):
|
||||
global solver
|
||||
global status
|
||||
if status != "initialized":
|
||||
return {"Status": "model not initialized"}
|
||||
background_tasks.add_task(lambda : task(request))
|
||||
|
||||
@app.post("/try_load")
|
||||
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
|
||||
global solver
|
||||
global status
|
||||
if solver.load_model():
|
||||
status = "ready"
|
||||
return {"Status": "ok"}
|
||||
background_tasks.add_task(lambda : task(request))
|
||||
|
||||
|
||||
class BoardRequest(BaseModel):
    """Request body for /predict.

    Fix: the original annotated with ``List``/``Tuple`` but the file never
    imports them from ``typing`` (NameError at import time). The file already
    uses ``NeuralSolver | None`` (3.10+ syntax), so PEP 585 builtin generics
    are safe and need no import.
    """

    # Coordinates of the live cells on the board.
    Lives: list[tuple[int, int]]
    # "forward" steps the board forward; any other value predicts a predecessor.
    Direction: str
|
||||
|
||||
@app.post("/predict")
|
||||
def predict(request: BoardRequest):
|
||||
global solver
|
||||
global status
|
||||
if status != "trained":
|
||||
return {"status": "not trained yet"}
|
||||
if status != "ready":
|
||||
return {"Status": "model not ready yet"}
|
||||
inputs = np.zeros((1, solver.Width, solver.Height, 1))
|
||||
for (x, y) in request.lives:
|
||||
for (x, y) in request.Lives:
|
||||
inputs[0, x, y, 0] = 1.0
|
||||
res = None
|
||||
if request.direction == "forward":
|
||||
if request.Direction == "forward":
|
||||
res = solver.predict_forward(inputs)
|
||||
else:
|
||||
res = solver.predict_reverse(inputs)
|
||||
@@ -58,7 +84,7 @@ def predict(request: BoardRequest):
|
||||
for (x, y) in itertools.product(range(solver.Width), range(solver.Height)):
|
||||
if res[0, x, y, 0]:
|
||||
lives.add((x, y))
|
||||
return {"prediction": lives}
|
||||
return {"Lives": lives}
|
||||
|
||||
@app.post("/finish")
|
||||
def finish():
|
||||
@@ -68,4 +94,32 @@ def finish():
|
||||
status = "none"
|
||||
|
||||
|
||||
@app.post("/save")
|
||||
def save():
|
||||
global solver
|
||||
global status
|
||||
if status != "ready":
|
||||
return {"Status": "model not ready yet"}
|
||||
solver.save_model()
|
||||
return {"Status": "ok"}
|
||||
|
||||
@app.post("/load")
|
||||
def load():
|
||||
global solver
|
||||
global status
|
||||
if status != "initialized":
|
||||
return {"Status": "model not initialized yet"}
|
||||
if solver.load_model():
|
||||
status = "ready"
|
||||
return {"Status": "ok"}
|
||||
return {"Status": "model loading failed"}
|
||||
|
||||
|
||||
|
||||
@app.get("/status")
|
||||
def get_status():
|
||||
global status
|
||||
return {"Status": status}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
numpy>=2.0.2
fastapi>=0.115.7
pydantic>=2.10.6
setuptools>=75.8.0
tensorflow>=2.18.0
|
||||
Reference in New Issue
Block a user