add: as service

This commit is contained in:
h z
2025-01-26 16:13:49 +00:00
parent 1e6d11ebcf
commit 4f440e7e20
3 changed files with 145 additions and 11 deletions

71
app.py Normal file
View File

@@ -0,0 +1,71 @@
import itertools
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
import numpy as np
from NeuralSolver import NeuralSolver
from typing import List, Tuple
app = FastAPI()
solver: NeuralSolver | None = None
status: str = "none"
def task():
global solver
global status
status = "training forward"
solver.train_forward(1000)
status = "training backward"
solver.train_backward(1000)
status = "trained"
class InitRequest(BaseModel):
width: int
height: int
quotientX: bool = False
quotientY: bool = False
@app.post("/initialize")
def initialize(request: InitRequest, background_tasks: BackgroundTasks):
global solver
global status
if status != "none":
return {"status": "instance already existed"}
solver = NeuralSolver(request.width, request.height, request.quotientX, request.quotientY)
background_tasks.add_task(task)
return {"status": "initializing"}
class BoardRequest(BaseModel):
lives: List[Tuple[int, int]]
direction: str
@app.post("/predict")
def predict(request: BoardRequest):
global solver
global status
if status != "trained":
return {"status": "not trained yet"}
inputs = np.zeros((1, solver.Width, solver.Height, 1))
for (x, y) in request.lives:
inputs[0, x, y, 0] = 1.0
res = None
if request.direction == "forward":
res = solver.predict_forward(inputs)
else:
res = solver.predict_reverse(inputs)
lives = set()
for (x, y) in itertools.product(range(solver.Width), range(solver.Height)):
if res[0, x, y, 0]:
lives.add((x, y))
return {"prediction": lives}
@app.post("/finish")
def finish():
global solver
global status
solver = None
status = "none"