diff --git a/NeuralSolver/NeuralSolver.py b/NeuralSolver/NeuralSolver.py index 44c2f84..4bf3392 100644 --- a/NeuralSolver/NeuralSolver.py +++ b/NeuralSolver/NeuralSolver.py @@ -4,6 +4,7 @@ from tensorflow import keras from Board import Board import random +import os class NeuralSolver: def __init__(self, width, height, quotient_x=False, quotient_y=False): @@ -97,4 +98,21 @@ class NeuralSolver: board.evaluate() for (cx, cy) in board.lives: y[i, cx, cy, 0] = 1.0 - return x, y \ No newline at end of file + return x, y + + + def spec(self): + return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}" + + + def save_model(self): + self.ForwardModel.save_weights(f"ForwardModel_{self.spec()}.h5") + self.ReverseModel.save_weights(f"ReverseModel_{self.spec()}.h5") + + + def load_model(self): + if os.path.exists(f"ForwardModel_{self.spec()}.h5") and os.path.exists(f"ReverseModel_{self.spec()}.h5"): + self.ForwardModel.load_weights(f"ForwardModel_{self.spec()}.h5") + self.ReverseModel.load_weights(f"ReverseModel_{self.spec()}.h5") + return True + return False \ No newline at end of file diff --git a/app.py b/app.py index aca14ff..042a15f 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,4 @@ import itertools - from fastapi import FastAPI, BackgroundTasks from pydantic import BaseModel import numpy as np @@ -10,47 +9,74 @@ app = FastAPI() solver: NeuralSolver | None = None status: str = "none" -def task(): +def task(request): global solver global status status = "training forward" - solver.train_forward(1000) + solver.train_forward(request.ForwardDatasetSize, request.ForwardBatchSize, request.ForwardEpochs) status = "training backward" - solver.train_backward(1000) - status = "trained" + solver.train_backward(request.BackwardDatasetSize, request.BackwardBatchSize, request.BackwardEpochs) + status = "ready" class InitRequest(BaseModel): - width: int - height: int - quotientX: bool = False - quotientY: bool = False + Width: int + Height: int + QuotientX: bool = False + QuotientY: bool = False @app.post("/initialize") -def initialize(request: InitRequest, background_tasks: BackgroundTasks): +def initialize(request: InitRequest): global solver global status if status != "none": - return {"status": "instance already existed"} - solver = NeuralSolver(request.width, request.height, request.quotientX, request.quotientY) - background_tasks.add_task(task) - return {"status": "initializing"} + return {"Status": "instance already existed"} + solver = NeuralSolver(request.Width, request.Height, request.QuotientX, request.QuotientY) + status = "initialized" + return {"Status": status} + + +class TrainRequest(BaseModel): + ForwardDatasetSize: int = 1000 + ForwardBatchSize: int = 8 + ForwardEpochs: int = 10 + BackwardDatasetSize: int = 1000 + BackwardBatchSize: int = 8 + BackwardEpochs: int = 10 + + +@app.post("/train") +def train(request: TrainRequest, background_tasks: BackgroundTasks): + global solver + global status + if status != "initialized": + return {"Status": "model not initialized"} + background_tasks.add_task(lambda : task(request)) + +@app.post("/try_load") +def try_load(request: TrainRequest, background_tasks: BackgroundTasks): + global solver + global status + if solver.load_model(): + status = "ready" + return {"Status": "ok"} + background_tasks.add_task(lambda : task(request)) class BoardRequest(BaseModel): - lives: List[Tuple[int, int]] - direction: str + Lives: List[Tuple[int, int]] + Direction: str @app.post("/predict") def predict(request: BoardRequest): global solver global status - if status != "trained": - return {"status": "not trained yet"} + if status != "ready": + return {"Status": "model not ready yet"} inputs = np.zeros((1, solver.Width, solver.Height, 1)) - for (x, y) in request.lives: + for (x, y) in request.Lives: inputs[0, x, y, 0] = 1.0 res = None - if request.direction == "forward": + if request.Direction == "forward": res = solver.predict_forward(inputs) else: res = solver.predict_reverse(inputs) @@ -58,7 +84,7 @@ def predict(request: BoardRequest): for (x, y) in itertools.product(range(solver.Width), range(solver.Height)): if res[0, x, y, 0]: lives.add((x, y)) - return {"prediction": lives} + return {"Lives": lives} @app.post("/finish") def finish(): @@ -68,4 +94,32 @@ def finish(): status = "none" +@app.post("/save") +def save(): + global solver + global status + if status != "ready": + return {"Status": "model not ready yet"} + solver.save_model() + return {"Status": "ok"} + +@app.post("/load") +def load(): + global solver + global status + if status != "initialized": + return {"Status": "model not initialized yet"} + if solver.load_model(): + status = "ready" + return {"Status": "ok"} + return {"Status": "model loading failed"} + + + +@app.get("/status") +def get_status(): + global status + return {"Status": status} + + diff --git a/requirements.txt b/requirements.txt index bd63fdb..2407bea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,5 @@ -numpy~=2.0.2 -tensorflow~=2.18.0 \ No newline at end of file +numpy>=2.0.2 +fastapi>=0.115.7 +pydantic>=2.10.6 +setuptools>=75.8.0 +tensorflow>=2.18.0 \ No newline at end of file