improve: provide api to save/load/get status

This commit is contained in:
h z
2025-01-28 13:55:27 +00:00
parent 4f440e7e20
commit 34e9b95145
3 changed files with 99 additions and 24 deletions

View File

@@ -4,6 +4,7 @@ from tensorflow import keras
from Board import Board
import random
import os
class NeuralSolver:
def __init__(self, width, height, quotient_x=False, quotient_y=False):
@@ -97,4 +98,21 @@ class NeuralSolver:
board.evaluate()
for (cx, cy) in board.lives:
y[i, cx, cy, 0] = 1.0
return x, y
return x, y
def spec(self):
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}"
def save_model(self):
self.ForwardModel.save_weights(f"ForwardModel_{self.spec()}.h5")
self.ReverseModel.save_weights(f"ReverseModel_{self.spec()}.h5")
def load_model(self):
if os.path.exists(f"ForwardModel_{self.spec()}.h5") and os.path.exists(f"ReverseModel_{self.spec()}.h5"):
self.ForwardModel.load_weights(f"ForwardModel_{self.spec()}.h5")
self.ReverseModel.load_weights(f"ReverseModel_{self.spec()}.h5")
return True
return False