add: as service
This commit is contained in:
71
app.py
Normal file
71
app.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import itertools
|
||||
|
||||
from fastapi import FastAPI, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
from NeuralSolver import NeuralSolver
|
||||
from typing import List, Tuple
|
||||
|
||||
app = FastAPI()
|
||||
solver: NeuralSolver | None = None
|
||||
status: str = "none"
|
||||
|
||||
def task():
|
||||
global solver
|
||||
global status
|
||||
status = "training forward"
|
||||
solver.train_forward(1000)
|
||||
status = "training backward"
|
||||
solver.train_backward(1000)
|
||||
status = "trained"
|
||||
|
||||
class InitRequest(BaseModel):
|
||||
width: int
|
||||
height: int
|
||||
quotientX: bool = False
|
||||
quotientY: bool = False
|
||||
|
||||
@app.post("/initialize")
|
||||
def initialize(request: InitRequest, background_tasks: BackgroundTasks):
|
||||
global solver
|
||||
global status
|
||||
if status != "none":
|
||||
return {"status": "instance already existed"}
|
||||
solver = NeuralSolver(request.width, request.height, request.quotientX, request.quotientY)
|
||||
background_tasks.add_task(task)
|
||||
return {"status": "initializing"}
|
||||
|
||||
|
||||
class BoardRequest(BaseModel):
|
||||
lives: List[Tuple[int, int]]
|
||||
direction: str
|
||||
|
||||
@app.post("/predict")
|
||||
def predict(request: BoardRequest):
|
||||
global solver
|
||||
global status
|
||||
if status != "trained":
|
||||
return {"status": "not trained yet"}
|
||||
inputs = np.zeros((1, solver.Width, solver.Height, 1))
|
||||
for (x, y) in request.lives:
|
||||
inputs[0, x, y, 0] = 1.0
|
||||
res = None
|
||||
if request.direction == "forward":
|
||||
res = solver.predict_forward(inputs)
|
||||
else:
|
||||
res = solver.predict_reverse(inputs)
|
||||
lives = set()
|
||||
for (x, y) in itertools.product(range(solver.Width), range(solver.Height)):
|
||||
if res[0, x, y, 0]:
|
||||
lives.add((x, y))
|
||||
return {"prediction": lives}
|
||||
|
||||
@app.post("/finish")
|
||||
def finish():
|
||||
global solver
|
||||
global status
|
||||
solver = None
|
||||
status = "none"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user