Files
2025-01-29 21:24:10 +00:00

128 lines
3.2 KiB
Python

import itertools
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
import numpy as np
from NeuralSolver import NeuralSolver
from typing import List, Tuple
# Application object; every route decorator in this file registers on it.
app = FastAPI()
# The one solver instance served by this process; stays None until
# /initialize constructs it and is set back to None by /finish.
solver: NeuralSolver | None = None
# Lifecycle state shared by all endpoints. Values assigned in this file:
# "none", "initialized", "training forward", "training backward", "ready".
status: str = "none"
def task(request):
    """Train the forward model, then the backward model, tracking progress in ``status``.

    Scheduled as a background task by the ``/train`` and ``/try_load``
    endpoints. ``status`` moves through "training forward" and
    "training backward" and ends at "ready" when both phases succeed.

    Args:
        request: Object exposing the ``ForwardDatasetSize``, ``ForwardBatchSize``,
            ``ForwardEpochs``, ``BackwardDatasetSize``, ``BackwardBatchSize`` and
            ``BackwardEpochs`` attributes (a ``TrainRequest``).

    Raises:
        Exception: Anything raised by the solver is re-raised after ``status``
            has been rolled back, so the failure is not swallowed.
    """
    global solver
    global status
    try:
        status = "training forward"
        solver.train_forward(
            request.ForwardDatasetSize,
            request.ForwardBatchSize,
            request.ForwardEpochs,
        )
        status = "training backward"
        solver.train_backward(
            request.BackwardDatasetSize,
            request.BackwardBatchSize,
            request.BackwardEpochs,
        )
    except Exception:
        # Without this rollback a failed run would leave the service reporting
        # "training ..." forever, and every endpoint except /finish would
        # refuse to act. If the solver was discarded while training ran,
        # fall back to "none" instead of claiming an instance exists.
        status = "initialized" if solver is not None else "none"
        raise
    status = "ready"
class InitRequest(BaseModel):
    # Request body for POST /initialize. All four fields are passed, in this
    # order, to the NeuralSolver constructor.

    # Board dimensions (required).
    Width: int
    Height: int

    # Optional flags, off by default.
    # NOTE(review): presumably these select a wrap-around (quotient) topology
    # along each axis -- confirm against NeuralSolver.
    QuotientX: bool = False
    QuotientY: bool = False
@app.post("/initialize")
def initialize(request: InitRequest):
    """Create the process-wide solver if none exists yet.

    Args:
        request: Board dimensions and quotient flags for the new solver.

    Returns:
        ``{"Status": "initialized"}`` when a solver was created, or
        ``{"Status": "instance already existed"}`` when the service is in any
        state other than "none".
    """
    global solver, status
    if status == "none":
        solver = NeuralSolver(
            request.Width,
            request.Height,
            request.QuotientX,
            request.QuotientY,
        )
        status = "initialized"
        return {"Status": status}
    return {"Status": "instance already existed"}
class TrainRequest(BaseModel):
    # Request body for POST /train and POST /try_load. Every field is optional.

    # Passed to solver.train_forward(dataset_size, batch_size, epochs).
    ForwardDatasetSize: int = 1000
    ForwardBatchSize: int = 8
    ForwardEpochs: int = 10

    # Passed to solver.train_backward(dataset_size, batch_size, epochs).
    BackwardDatasetSize: int = 1000
    BackwardBatchSize: int = 8
    BackwardEpochs: int = 10
@app.post("/train")
def train(request: TrainRequest, background_tasks: BackgroundTasks):
    """Schedule forward and backward training as a background task.

    Args:
        request: Dataset sizes, batch sizes and epoch counts for both phases.
        background_tasks: FastAPI-injected queue the training job is added to.

    Returns:
        ``{"Status": "model not initialized"}`` unless the service is in the
        "initialized" state; otherwise ``{"Status": "training forward"}`` once
        the job has been queued.
    """
    global solver
    global status
    if status != "initialized":
        return {"Status": "model not initialized"}
    # Leave the "initialized" state right away. The background task only
    # starts after the response is sent, so without this a second /train
    # arriving in between would pass the guard and queue duplicate training.
    status = "training forward"
    background_tasks.add_task(task, request)
    # Report the outcome like every other endpoint instead of returning null.
    return {"Status": status}
@app.post("/try_load")
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
    """Load a saved model if possible; otherwise fall back to training.

    Args:
        request: Training parameters used only when loading fails.
        background_tasks: FastAPI-injected queue the training job is added to.

    Returns:
        ``{"Status": "model not initialized"}`` unless the service is in the
        "initialized" state; ``{"Status": "ok"}`` when ``solver.load_model()``
        reports success (the service becomes "ready"); otherwise
        ``{"Status": "training forward"}`` once training has been queued.
    """
    global solver
    global status
    if status != "initialized":
        return {"Status": "model not initialized"}
    if solver.load_model():
        status = "ready"
        return {"Status": "ok"}
    # Loading failed, so train instead. Leave the "initialized" state before
    # the response is sent so that a concurrent /train or /try_load cannot
    # queue a second training job.
    status = "training forward"
    background_tasks.add_task(task, request)
    # Tell the caller that training was started instead of returning null.
    return {"Status": status}
class BoardRequest(BaseModel):
    # Request body for POST /predict.

    # (x, y) coordinates of the cells that are set to 1.0 in the input grid.
    Lives: list[tuple[int, int]]

    # "forward" selects solver.predict_forward; any other value selects
    # solver.predict_reverse.
    Direction: str
@app.post("/predict")
def predict(request: BoardRequest):
    """Run the forward or reverse model on a board given as a list of live cells.

    The live cells are rasterised into a ``(1, Width, Height, 1)`` array of
    zeros and ones, passed to the solver, and every cell whose output value is
    truthy is reported back.

    Args:
        request: Live-cell coordinates and the direction to predict in.
            ``Direction == "forward"`` uses ``solver.predict_forward``; any
            other value uses ``solver.predict_reverse``.

    Returns:
        ``{"Status": "model not ready yet"}`` unless the service is "ready";
        ``{"Status": "cell out of range"}`` when a coordinate lies outside the
        board; otherwise ``{"Lives": [...]}`` with the predicted live cells as
        (x, y) pairs ordered by x, then y.
    """
    global solver
    global status
    if status != "ready":
        return {"Status": "model not ready yet"}

    width, height = solver.Width, solver.Height
    inputs = np.zeros((1, width, height, 1))
    for x, y in request.Lives:
        # NumPy indexing would silently map a negative coordinate onto the
        # opposite edge and raise IndexError (an HTTP 500) for a too-large
        # one, so reject both explicitly.
        if not (0 <= x < width and 0 <= y < height):
            return {"Status": "cell out of range"}
        inputs[0, x, y, 0] = 1.0

    if request.Direction == "forward":
        res = solver.predict_forward(inputs)
    else:
        res = solver.predict_reverse(inputs)

    # product() visits each cell exactly once, so a list holds the same cells
    # the original set did, in a deterministic order. Both serialise to a
    # JSON array.
    lives = [
        (x, y)
        for x, y in itertools.product(range(width), range(height))
        if res[0, x, y, 0]
    ]
    return {"Lives": lives}
@app.post("/finish")
def finish():
    """Discard the current solver and return the service to the "none" state.

    Returns:
        ``{"Status": "none"}``, matching the response shape of the other
        endpoints (this endpoint previously returned null).
    """
    global solver
    global status
    solver = None
    status = "none"
    return {"Status": status}
@app.post("/save")
def save():
    """Ask the solver to save its model.

    Returns:
        ``{"Status": "ok"}`` after ``solver.save_model()`` has been called, or
        ``{"Status": "model not ready yet"}`` when the service is not "ready".
    """
    global solver, status
    if status == "ready":
        solver.save_model()
        return {"Status": "ok"}
    return {"Status": "model not ready yet"}
@app.post("/load")
def load():
    """Ask the solver to load a saved model and mark the service ready on success.

    Returns:
        ``{"Status": "model not initialized yet"}`` unless the service is in
        the "initialized" state; ``{"Status": "ok"}`` when
        ``solver.load_model()`` returns a truthy value; otherwise
        ``{"Status": "model loading failed"}``.
    """
    global solver, status
    if status != "initialized":
        return {"Status": "model not initialized yet"}
    loaded = solver.load_model()
    if not loaded:
        return {"Status": "model loading failed"}
    status = "ready"
    return {"Status": "ok"}
@app.get("/status")
def get_status():
    """Report the service's current lifecycle state as ``{"Status": <state>}``."""
    # Reading a module-level name needs no ``global`` declaration.
    current = status
    return {"Status": current}