126 lines
3.1 KiB
Python
126 lines
3.1 KiB
Python
import itertools
|
|
from fastapi import FastAPI, BackgroundTasks
|
|
from pydantic import BaseModel
|
|
import numpy as np
|
|
from NeuralSolver import NeuralSolver
|
|
from typing import List, Tuple
|
|
|
|
# HTTP application exposing the solver endpoints defined below.
app = FastAPI()

# Process-wide state shared by all endpoints and the background trainer.
# `status` is driven by the endpoints: "none" (no solver) -> "initialized"
# (after /initialize) -> "training forward" -> "training backward" -> "ready"
# (after training or a successful load); /finish resets it to "none".
status: str = "none"
# The NeuralSolver created by /initialize; None until then and after /finish.
solver: NeuralSolver | None = None
|
|
|
|
def task(request):
    """Train the current solver forward, then backward, updating `status`.

    Intended to run as a FastAPI background task. `request` is a
    TrainRequest; its Forward*/Backward* fields are passed positionally to
    `train_forward` / `train_backward` as (dataset size, batch size, epochs).

    On success the status moves "training forward" -> "training backward"
    -> "ready". If either training step raises, the status is reset to
    "initialized" so training can be retried. Without the reset the service
    would stay stuck in a "training ..." state. The exception is then
    re-raised so the failure still propagates.

    Status writes are skipped once the global `solver` no longer refers to
    the instance being trained (e.g. after /finish). This stops a stale task
    from overwriting the state of a reset or re-initialized service.
    """
    # Local handle: the global may be replaced by /finish while we train.
    model = solver

    def set_status(value):
        global status
        # Only the task training the *current* solver may change the status.
        if solver is model:
            status = value

    try:
        set_status("training forward")
        model.train_forward(
            request.ForwardDatasetSize,
            request.ForwardBatchSize,
            request.ForwardEpochs,
        )
        set_status("training backward")
        model.train_backward(
            request.BackwardDatasetSize,
            request.BackwardBatchSize,
            request.BackwardEpochs,
        )
    except Exception:
        # Don't leave the service wedged in a "training ..." state.
        set_status("initialized")
        raise
    set_status("ready")
|
|
|
|
class InitRequest(BaseModel):
    """Request body for /initialize: board dimensions and quotient flags."""

    # Board size; passed as the first two arguments of NeuralSolver and used
    # by /predict as the sizes of the two board axes.
    Width: int
    Height: int
    # Passed straight through to NeuralSolver. Presumably these make the board
    # wrap around (quotient / torus topology) along X / Y — TODO confirm
    # against NeuralSolver.
    QuotientX: bool = False
    QuotientY: bool = False
|
|
|
|
@app.post("/initialize")
def initialize(request: InitRequest):
    """Create the process-wide NeuralSolver for a board of the requested size.

    Allowed only while no solver exists (status "none"). Otherwise the
    existing instance is left untouched and an explanatory status is
    returned. On success the status becomes "initialized".
    """
    global solver, status

    if status == "none":
        solver = NeuralSolver(
            request.Width,
            request.Height,
            request.QuotientX,
            request.QuotientY,
        )
        status = "initialized"
        return {"Status": status}

    return {"Status": "instance already existed"}
|
|
|
|
|
|
class TrainRequest(BaseModel):
    """Request body for /train and /try_load: training hyper-parameters.

    The Forward* fields are passed to `solver.train_forward` and the
    Backward* fields to `solver.train_backward`, each positionally as
    (dataset size, batch size, epochs).
    """

    # Parameters for the forward model's training run.
    ForwardDatasetSize: int = 1000
    ForwardBatchSize: int = 8
    ForwardEpochs: int = 10
    # Parameters for the backward model's training run.
    BackwardDatasetSize: int = 1000
    BackwardBatchSize: int = 8
    BackwardEpochs: int = 10
|
|
|
|
|
|
@app.post("/train")
def train(request: TrainRequest, background_tasks: BackgroundTasks):
    """Schedule forward-then-backward training of the solver in the background.

    Requires status "initialized". Returns {"Status": <new status>}, matching
    the response shape of the other endpoints.
    """
    global status

    if status != "initialized":
        return {"Status": "model not initialized"}

    # Claim the training state *before* queuing. Background tasks only start
    # after the response is sent, so without this a second /train arriving in
    # that window would also pass the check above and queue a duplicate run.
    status = "training forward"
    background_tasks.add_task(task, request)
    return {"Status": status}
|
|
|
|
@app.post("/try_load")
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
    """Load a saved model if possible, otherwise fall back to training one.

    Requires status "initialized", the same precondition as /load.
    Returns {"Status": "ok"} when the saved model was loaded and the status
    is now "ready". Otherwise it schedules `task(request)` in the background
    and returns {"Status": <new status>}.
    """
    global status

    # Guard first: before /initialize `solver` is None (calling load_model on
    # it would be an HTTP 500), and during training a fallback would queue a
    # second concurrent training run.
    if status != "initialized":
        return {"Status": "model not initialized yet"}

    if solver.load_model():
        status = "ready"
        return {"Status": "ok"}

    # Loading failed: claim the training state before queuing so a repeated
    # request cannot schedule another run before this one starts.
    status = "training forward"
    background_tasks.add_task(task, request)
    return {"Status": status}
|
|
|
|
|
|
class BoardRequest(BaseModel):
    """Request body for /predict: a board given by its live cells."""

    # (x, y) coordinates of live cells; /predict indexes the board as
    # [0, x, y, 0], so x runs along Width and y along Height.
    Lives: List[Tuple[int, int]]
    # "forward" selects solver.predict_forward; any other value selects
    # solver.predict_reverse.
    Direction: str
|
|
|
|
@app.post("/predict")
def predict(request: BoardRequest):
    """Run the solver forward or in reverse on the given board.

    `request.Lives` lists the (x, y) coordinates of live cells.
    `request.Direction` selects the model: "forward" uses
    `solver.predict_forward`, and any other value uses
    `solver.predict_reverse`.

    Returns one of:
    - {"Lives": <set of (x, y)>} with the cells the prediction marks alive;
    - {"Status": "model not ready yet"} if the solver is not "ready";
    - {"Status": "cell (x, y) out of range"} if a coordinate lies off the
      board.
    """
    if status != "ready":
        return {"Status": "model not ready yet"}

    width, height = solver.Width, solver.Height

    # Board encoded as an array of shape (1, Width, Height, 1) with 1.0 on
    # live cells.
    inputs = np.zeros((1, width, height, 1))
    for x, y in request.Lives:
        # Validate before indexing. numpy would silently wrap a negative
        # coordinate onto the opposite edge, and raise IndexError (HTTP 500)
        # for one past the end.
        if not (0 <= x < width and 0 <= y < height):
            return {"Status": f"cell ({x}, {y}) out of range"}
        inputs[0, x, y, 0] = 1.0

    if request.Direction == "forward":
        res = solver.predict_forward(inputs)
    else:
        res = solver.predict_reverse(inputs)

    # NOTE(review): any truthy (non-zero) output marks a cell alive. If the
    # predict_* methods return probabilities rather than 0/1, this probably
    # wants an explicit threshold — confirm against NeuralSolver.
    lives = {
        (x, y)
        for x, y in itertools.product(range(width), range(height))
        if res[0, x, y, 0]
    }
    return {"Lives": lives}
|
|
|
|
@app.post("/finish")
def finish():
    """Discard the current solver and reset the service to status "none".

    Afterwards /initialize may be called again. Returns {"Status": "none"},
    matching the response shape of the other endpoints.
    """
    global solver, status

    # NOTE(review): nothing here cancels a background training task that is
    # already running; it keeps its own reference to the old solver.
    solver = None
    status = "none"
    return {"Status": status}
|
|
|
|
|
|
@app.post("/save")
def save():
    """Persist the solver via `solver.save_model()`.

    Only permitted once the status is "ready"; otherwise nothing is saved.
    """
    global solver, status

    if status == "ready":
        solver.save_model()
        return {"Status": "ok"}
    return {"Status": "model not ready yet"}
|
|
|
|
@app.post("/load")
def load():
    """Load previously saved weights into an initialized solver.

    Only permitted from status "initialized". When `solver.load_model()`
    reports success the status becomes "ready"; otherwise it is unchanged.
    """
    global solver, status

    if status == "initialized":
        loaded = solver.load_model()
        if loaded:
            status = "ready"
            return {"Status": "ok"}
        return {"Status": "model loading failed"}

    return {"Status": "model not initialized yet"}
|
|
|
|
|
|
|
|
@app.get("/status")
def get_status():
    """Report the service's current lifecycle status string."""
    # Read-only access to the module-level `status`; no `global` needed.
    current = status
    return {"Status": current}
|
|
|
|
|
|
|