# Source: InverseOfLife.NeuralSolver/NeuralSolver/NeuralSolver.py
# Snapshot: 2025-01-29 21:24:10 +00:00 (120 lines, 4.8 KiB, Python)

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import regularizers
from Board import Board
import random
import os
class NeuralSolver:
    """Learn a board's one-step dynamics forwards and backwards with small CNNs.

    ``ForwardModel`` maps a board state to its successor, i.e. the state after
    one ``Board.evaluate()`` step.  ``ReverseModel`` maps a final state to a
    candidate predecessor; it is trained through the frozen forward model so
    that ``ForwardModel(ReverseModel(y))`` reproduces ``y``.

    Boards are encoded as float32 arrays of shape ``(Width, Height, 1)`` with
    ``1.0`` for a live cell and ``0.0`` otherwise.
    """

    def __init__(self, width, height, quotient_x=False, quotient_y=False):
        """Create and build both models for a ``width`` x ``height`` board.

        Args:
            width: Board extent along the first array axis.
            height: Board extent along the second array axis.
            quotient_x: Passed to ``Board`` and encoded in the weight file
                names; presumably makes the x axis wrap around — TODO confirm
                against ``Board``.
            quotient_y: Same as ``quotient_x``, for the y axis.
        """
        self.Width = width
        self.Height = height
        self.QuotientX = quotient_x
        self.QuotientY = quotient_y
        self._build_forward_model()
        self._build_reverse_model()

    def _compile_forward(self):
        """(Re)compile ForwardModel.

        Keras snapshots each layer's ``trainable`` flag at compile time, so this
        must be called again whenever ``ForwardModel.trainable`` changes.
        """
        self.ForwardModel.compile(
            optimizer=keras.optimizers.Adam(learning_rate=0.001),
            loss=keras.losses.BinaryCrossentropy(),
            metrics=["accuracy"],
        )

    def _build_forward_model(self):
        """Build and compile ForwardModel: state -> per-cell live probability."""
        inputs = keras.Input(shape=(self.Width, self.Height, 1), name="InitialState")
        # NOTE(review): "same" padding zero-pads, i.e. treats cells beyond the
        # edge as dead; it does not model a wrap-around (quotient) topology.
        hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
        hidden = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(hidden)
        # 1x1 conv + sigmoid: one live probability per cell.
        outputs = keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(hidden)
        self.ForwardModel = keras.Model(inputs, outputs, name="ForwardModel")
        self._compile_forward()

    def _build_reverse_model(self):
        """Build ReverseModel: final state -> per-cell predecessor probability.

        It is not compiled on its own; it is only trained inside the composite
        model built by ``train_backward``.
        """
        inputs = keras.Input(shape=(self.Width, self.Height, 1), name="FinalState")
        hidden = keras.layers.Conv2D(
            32, 3, padding="same", activation="relu",
            kernel_regularizer=regularizers.l2(1e-4),
        )(inputs)
        hidden = keras.layers.BatchNormalization()(hidden)
        hidden = keras.layers.Conv2D(
            32, 5, padding="same", activation="relu",
            kernel_regularizer=regularizers.l2(1e-4),
        )(hidden)
        hidden = keras.layers.Conv2D(
            32, 3, padding="same", activation="relu",
            kernel_regularizer=regularizers.l2(1e-4),
        )(hidden)
        outputs = keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(hidden)
        self.ReverseModel = keras.Model(inputs, outputs, name="ReverseModel")

    def train_forward(self, dataset_size, batch_size=8, epochs=10):
        """Fit ForwardModel on ``dataset_size`` fresh (state, successor) pairs.

        Returns:
            The Keras ``History`` object produced by ``fit``.
        """
        x, y = self.generate_training_data(dataset_size)
        return self.ForwardModel.fit(
            x=x,
            y=y,
            batch_size=batch_size,
            epochs=epochs,
            verbose=1,
        )

    def train_backward(self, dataset_size, batch_size=8, epochs=10):
        """Fit ReverseModel through a frozen ForwardModel.

        The composite ``ForwardModel(ReverseModel(y))`` is trained to reproduce
        ``y``. ReverseModel therefore learns to emit a state whose predicted
        successor is ``y``. ForwardModel is frozen while the composite is
        compiled and fitted, so only ReverseModel's weights change. Afterwards
        ForwardModel is made trainable again and recompiled, so a later
        ``train_forward`` still trains it.

        Returns:
            The Keras ``History`` object produced by the composite ``fit``.
        """
        # Only the successor states are needed: they are both input and target.
        _, y = self.generate_training_data(dataset_size)
        self.ForwardModel.trainable = False
        try:
            final_state = keras.Input(shape=(self.Width, self.Height, 1))
            # The forward model receives ReverseModel's soft sigmoid output (not
            # a thresholded board), which keeps the composite differentiable.
            reconstructed = self.ForwardModel(self.ReverseModel(final_state))
            composite_model = keras.Model(final_state, reconstructed, name="CompositeModel")
            composite_model.compile(
                optimizer=keras.optimizers.Adam(learning_rate=0.001),
                loss=keras.losses.BinaryCrossentropy(),
                metrics=["accuracy"],
            )
            return composite_model.fit(
                x=y,
                y=y,
                batch_size=batch_size,
                epochs=epochs,
                verbose=1,
            )
        finally:
            # Undo the freeze; without a recompile ForwardModel.fit would keep
            # treating its weights as non-trainable.
            self.ForwardModel.trainable = True
            self._compile_forward()

    @staticmethod
    def _as_batch(b):
        """Return ``b`` as a float32 array, adding batch and channel axes to a 2-D board."""
        batch = np.asarray(b, dtype=np.float32)
        if batch.ndim == 2:
            batch = batch[None, ..., None]
        return batch

    def predict_forward(self, b):
        """Predict the successor of board(s) ``b`` as a boolean array (live if p > 0.5).

        ``b`` may be a single 2-D board of shape (Width, Height), given as an
        array or nested lists, or an already batched (N, Width, Height, 1)
        array. A 2-D input yields a result of shape (1, Width, Height, 1).
        """
        return self.ForwardModel.predict(self._as_batch(b)) > 0.5

    def predict_reverse(self, b):
        """Predict a predecessor of board(s) ``b`` as a boolean array (live if p > 0.5).

        Accepts the same input forms as ``predict_forward``.
        """
        return self.ReverseModel.predict(self._as_batch(b)) > 0.5

    @staticmethod
    def _mark_lives(grid, lives):
        """Set ``grid[cx, cy, 0] = 1.0`` for every ``(cx, cy)`` in ``lives`` inside ``grid``."""
        width, height = grid.shape[0], grid.shape[1]
        for cx, cy in lives:
            # Defensive bounds check: a negative coordinate would otherwise wrap
            # around silently in NumPy. Whether Board can report out-of-window
            # cells is not visible here.
            if 0 <= cx < width and 0 <= cy < height:
                grid[cx, cy, 0] = 1.0

    def generate_training_data(self, dataset_size=1000):
        """Generate ``dataset_size`` random (state, successor) pairs.

        Each sample starts from a freshly constructed ``Board`` (presumably
        empty — TODO confirm). It then toggles a random number of random cells:
        between ``Width*Height // 16`` and ``Width*Height`` toggles, where the
        same cell may be toggled more than once. Finally it advances the board
        one step with ``Board.evaluate()``.

        Returns:
            ``(x, y)``: float32 arrays of shape (dataset_size, Width, Height, 1)
            holding the initial states and their evaluated successors.
        """
        cell_count = self.Width * self.Height
        x = np.zeros((dataset_size, self.Width, self.Height, 1), dtype=np.float32)
        y = np.zeros_like(x)
        for i in range(dataset_size):
            board = Board(self.Width, self.Height, self.QuotientX, self.QuotientY)
            for _ in range(random.randint(cell_count // 16, cell_count)):
                # Arguments are evaluated left to right: x coordinate, then y.
                board.toggle(
                    random.randint(0, self.Width - 1),
                    random.randint(0, self.Height - 1),
                )
            self._mark_lives(x[i], board.lives)
            board.evaluate()
            self._mark_lives(y[i], board.lives)
        return x, y

    def spec(self):
        """Return the board-configuration tag used in weight file names.

        The trailing ``.weights`` makes the final names end in ``.weights.h5``,
        the suffix Keras 3's ``save_weights`` requires.
        """
        return f"{self.Width}x{self.Height}x{self.QuotientX}x{self.QuotientY}.weights"

    def _weight_paths(self):
        """Return the ``(forward, reverse)`` weight file paths, relative to the CWD."""
        spec = self.spec()
        return f"ForwardModel_{spec}.h5", f"ReverseModel_{spec}.h5"

    def save_model(self):
        """Save both models' weights to the paths given by ``_weight_paths``."""
        forward_path, reverse_path = self._weight_paths()
        self.ForwardModel.save_weights(forward_path)
        self.ReverseModel.save_weights(reverse_path)

    def load_model(self):
        """Load both models' weights if both files exist.

        Returns:
            True if the weights were loaded. False, without touching either
            model, if either file is missing.
        """
        forward_path, reverse_path = self._weight_paths()
        if not (os.path.exists(forward_path) and os.path.exists(reverse_path)):
            return False
        self.ForwardModel.load_weights(forward_path)
        self.ReverseModel.load_weights(reverse_path)
        return True