"""HTTP front-end around a single, process-wide ``NeuralSolver`` instance.

The server holds at most one solver at a time. The module-level ``status``
string tracks its lifecycle:

    "none" -> "initialized" -> "training forward" -> "training backward"
           -> "ready"

A successful /load or /try_load moves directly from "initialized" to
"ready". /finish returns to "none" from any state.
"""
import itertools
from typing import List, Tuple

import numpy as np
from fastapi import BackgroundTasks, FastAPI
from pydantic import BaseModel

from NeuralSolver import NeuralSolver

app = FastAPI()

# The solver served by this process; None before /initialize and after /finish.
solver: NeuralSolver | None = None
# Lifecycle state reported by /status and checked by every endpoint.
status: str = "none"


def task(request: "TrainRequest") -> None:
    """Train the current solver forward, then backward, updating ``status``.

    Scheduled as a FastAPI background task by /train and /try_load.

    The solver is captured once up front. A concurrent /finish (which sets the
    global to None) or a re-/initialize therefore cannot make later steps call
    into a missing or different solver. ``status`` is only updated while the
    captured solver is still the active one.

    If training raises, ``status`` falls back to "initialized" so training can
    be retried, and the exception is re-raised rather than swallowed.
    """
    global status
    current = solver
    if current is None:
        # /finish ran before this task started; nothing to train.
        return
    try:
        status = "training forward"
        current.train_forward(
            request.ForwardDatasetSize,
            request.ForwardBatchSize,
            request.ForwardEpochs,
        )
        if solver is not current:
            # The solver was discarded or replaced while training forward.
            return
        status = "training backward"
        current.train_backward(
            request.BackwardDatasetSize,
            request.BackwardBatchSize,
            request.BackwardEpochs,
        )
    except Exception:
        if solver is current:
            status = "initialized"
        raise
    if solver is current:
        status = "ready"


class InitRequest(BaseModel):
    """Board geometry passed to the ``NeuralSolver`` constructor."""

    Width: int
    Height: int
    QuotientX: bool = False
    QuotientY: bool = False


@app.post("/initialize")
def initialize(request: InitRequest):
    """Create the solver if none exists yet; state becomes "initialized"."""
    global solver
    global status
    if status != "none":
        return {"Status": "instance already existed"}
    solver = NeuralSolver(
        request.Width, request.Height, request.QuotientX, request.QuotientY
    )
    status = "initialized"
    return {"Status": status}


class TrainRequest(BaseModel):
    """Dataset size, batch size and epoch count for both training passes."""

    ForwardDatasetSize: int = 1000
    ForwardBatchSize: int = 8
    ForwardEpochs: int = 10
    BackwardDatasetSize: int = 1000
    BackwardBatchSize: int = 8
    BackwardEpochs: int = 10


def _schedule_training(
    request: TrainRequest, background_tasks: BackgroundTasks
) -> dict:
    """Mark training as started and queue ``task`` to run after the response."""
    global status
    # Leave "initialized" immediately. Background tasks only run after the
    # response has been sent, so without this a second request arriving in
    # between would pass the "initialized" check and queue another run.
    status = "training forward"
    background_tasks.add_task(task, request)
    return {"Status": status}


@app.post("/train")
def train(request: TrainRequest, background_tasks: BackgroundTasks):
    """Schedule forward + backward training of the initialized solver."""
    if status != "initialized":
        return {"Status": "model not initialized"}
    return _schedule_training(request, background_tasks)


@app.post("/try_load")
def try_load(request: TrainRequest, background_tasks: BackgroundTasks):
    """Load a saved model if ``load_model`` succeeds; otherwise schedule training."""
    global status
    if status != "initialized":
        return {"Status": "model not initialized"}
    if solver.load_model():
        status = "ready"
        return {"Status": "ok"}
    return _schedule_training(request, background_tasks)


class BoardRequest(BaseModel):
    """A board given as its live (x, y) cells plus the prediction direction."""

    Lives: List[Tuple[int, int]]
    Direction: str


@app.post("/predict")
def predict(request: BoardRequest):
    """Run the solver on a board and return the predicted live cells.

    ``Direction == "forward"`` uses ``predict_forward``; any other value uses
    ``predict_reverse``. The response lists live cells as ``[x, y]`` pairs,
    sorted by coordinate. A cell outside the board yields a Status error.
    """
    current = solver  # local copy so a concurrent /finish cannot null it mid-call
    if status != "ready" or current is None:
        return {"Status": "model not ready yet"}
    width, height = current.Width, current.Height
    inputs = np.zeros((1, width, height, 1))
    for x, y in request.Lives:
        # Reject out-of-range cells. Otherwise NumPy would raise IndexError
        # (HTTP 500) or, for negative values, silently wrap to the far edge.
        if not (0 <= x < width and 0 <= y < height):
            return {"Status": f"cell ({x}, {y}) out of bounds"}
        inputs[0, x, y, 0] = 1.0
    if request.Direction == "forward":
        res = current.predict_forward(inputs)
    else:
        res = current.predict_reverse(inputs)
    # A cell counts as alive when its output value is truthy (non-zero).
    lives = sorted(
        (x, y)
        for x, y in itertools.product(range(width), range(height))
        if res[0, x, y, 0]
    )
    return {"Lives": lives}


@app.post("/finish")
def finish():
    """Discard the current solver and return to the "none" state."""
    global solver
    global status
    solver = None
    status = "none"
    return {"Status": status}


@app.post("/save")
def save():
    """Persist the trained solver via ``save_model``; requires "ready"."""
    if status != "ready":
        return {"Status": "model not ready yet"}
    solver.save_model()
    return {"Status": "ok"}


@app.post("/load")
def load():
    """Try ``load_model`` on the initialized solver; "ready" on success."""
    global status
    if status != "initialized":
        return {"Status": "model not initialized yet"}
    if solver.load_model():
        status = "ready"
        return {"Status": "ok"}
    return {"Status": "model loading failed"}


@app.get("/status")
def get_status():
    """Report the current lifecycle state."""
    return {"Status": status}