improve: provide api to save/load/get status
This commit is contained in:
@@ -4,6 +4,7 @@ from tensorflow import keras
|
||||
|
||||
from Board import Board
|
||||
import random
|
||||
import os
|
||||
|
||||
class NeuralSolver:
|
||||
def __init__(self, width, height, quotient_x=False, quotient_y=False):
|
||||
@@ -97,4 +98,21 @@ class NeuralSolver:
|
||||
board.evaluate()
|
||||
for (cx, cy) in board.lives:
|
||||
y[i, cx, cy, 0] = 1.0
|
||||
return x, y
|
||||
return x, y
|
||||
|
||||
|
||||
def spec(self):
|
||||
return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}"
|
||||
|
||||
|
||||
def save_model(self):
|
||||
self.ForwardModel.save_weights(f"ForwardModel_{self.spec()}.h5")
|
||||
self.ReverseModel.save_weights(f"ReverseModel_{self.spec()}.h5")
|
||||
|
||||
|
||||
def load_model(self):
|
||||
if os.path.exists(f"ForwardModel_{self.spec()}.h5") and os.path.exists(f"ReverseModel_{self.spec()}.h5"):
|
||||
self.ForwardModel.load_weights(f"ForwardModel_{self.spec()}.h5")
|
||||
self.ReverseModel.load_weights(f"ReverseModel_{self.spec()}.h5")
|
||||
return True
|
||||
return False
|
||||
Reference in New Issue
Block a user