# Numerus-V1 / main.py
# SkillForge45's picture
# Update main.py
# 8dc3336 verified
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical
# Constants
# Vocabulary of recognised characters; a character's position in this string
# is its class index in the one-hot encoding.
CHAR_SET = '0123456789+-=* /'
NUM_CLASSES = len(CHAR_SET)
# Fixed number of character positions in an encoded equation and an encoded
# result, respectively.
# NOTE(review): each encoded equation is a dense
# (MAX_EQUATION_LENGTH, NUM_CLASSES) array, i.e. 32 million entries at these
# values -- confirm that sizes this large are intended.
MAX_EQUATION_LENGTH = 2_000_000
MAX_RESULT_LENGTH = 100_000
def one_hot_encode(s, max_length):
    """One-hot encode the characters of *s* against ``CHAR_SET``.

    Args:
        s: Text to encode. Only its first ``max_length`` characters are used.
        max_length: Number of rows in the returned array.

    Returns:
        A NumPy array of shape ``(max_length, NUM_CLASSES)``. Row ``i`` has a
        1 in the column given by the position of ``s[i]`` in ``CHAR_SET``.
        Rows for characters that are not in ``CHAR_SET``, and rows past the
        end of ``s``, stay all zeros.
    """
    matrix = np.zeros((max_length, NUM_CLASSES))
    for row, symbol in enumerate(s[:max_length]):
        column = CHAR_SET.find(symbol)
        if column == -1:
            # Unknown character: leave its row empty.
            continue
        matrix[row, column] = 1
    return matrix
def read_dataset(directory):
    """Load one-hot encoded ``equation = result`` pairs from *directory*.

    Every file in *directory* whose name ends with ``.txt`` is read line by
    line. Lines without an ``=`` are skipped. Each remaining line is split at
    its first ``=``: the text before it is the equation and everything after
    it is the result (so a line with several ``=`` no longer raises).

    Args:
        directory: Path of the directory that holds the ``.txt`` files.

    Returns:
        A ``(data, labels)`` tuple of NumPy arrays with shapes
        ``(samples, MAX_EQUATION_LENGTH, NUM_CLASSES)`` and
        ``(samples, MAX_RESULT_LENGTH, NUM_CLASSES)``. When no sample is
        found, both arrays keep those trailing dimensions with ``samples``
        equal to 0.

    Raises:
        OSError: If *directory* or one of its files cannot be read.
    """
    data = []
    labels = []
    # Sort so the sample order does not depend on the filesystem's listing order.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.txt'):
            continue
        path = os.path.join(directory, filename)
        # Explicit encoding so parsing does not depend on the locale default.
        with open(path, 'r', encoding='utf-8') as file:
            for line in file:
                # partition() splits at the first '=' only, so extra '=' signs
                # end up in the result instead of breaking the unpacking.
                equation, separator, result = line.strip().partition('=')
                if not separator:
                    continue
                data.append(one_hot_encode(equation.strip(), MAX_EQUATION_LENGTH))
                labels.append(one_hot_encode(result.strip(), MAX_RESULT_LENGTH))
    if not data:
        # np.array([]) would be 1-D; return empty arrays with the documented
        # rank so callers can still rely on the trailing dimensions.
        return (
            np.zeros((0, MAX_EQUATION_LENGTH, NUM_CLASSES)),
            np.zeros((0, MAX_RESULT_LENGTH, NUM_CLASSES)),
        )
    return np.array(data), np.array(labels)
# Read dataset
data, labels = read_dataset('.math_train')
# Fail early with a clear message instead of an opaque shape error later on.
if len(data) == 0:
    raise ValueError("No training samples found in '.math_train'")
# Make the per-position label layout explicit:
# (samples, MAX_RESULT_LENGTH, NUM_CLASSES).
labels = labels.reshape((labels.shape[0], MAX_RESULT_LENGTH, NUM_CLASSES))
# Build the model
# The last Dense layer emits one score per (position, class) pair. Reshaping
# to (MAX_RESULT_LENGTH, NUM_CLASSES) and applying softmax over the last axis
# yields a separate class distribution for every result position, which is
# what the one-hot labels (one 1 per position) describe. A single softmax over
# the flattened vector would make all positions share one distribution.
model = Sequential([
    Flatten(input_shape=(MAX_EQUATION_LENGTH, NUM_CLASSES)),
    Dense(128, activation='relu'),
    Dense(64, activation='relu'),
    Dense(MAX_RESULT_LENGTH * NUM_CLASSES),
    tf.keras.layers.Reshape((MAX_RESULT_LENGTH, NUM_CLASSES)),
    tf.keras.layers.Softmax(axis=-1),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Train the model
# Labels are passed in their 3-D layout so the loss and accuracy are computed
# per result position, matching the model's output shape.
model.fit(data, labels, epochs=50, batch_size=32)
# Function to solve an equation
def solve_equation(model, equation):
    """Predict the result string for *equation* with a trained model.

    Args:
        model: Object with a ``predict`` method. Its output for a one-sample
            batch must contain ``MAX_RESULT_LENGTH * NUM_CLASSES`` values.
        equation: Equation text; it is one-hot encoded to
            ``MAX_EQUATION_LENGTH`` positions before prediction.

    Returns:
        The highest-scoring character of each result position, joined into
        one string with leading and trailing whitespace removed.
    """
    # Prepend a batch axis so the model sees a batch containing one sample.
    batch = one_hot_encode(equation, MAX_EQUATION_LENGTH)[np.newaxis, ...]
    scores = model.predict(batch).reshape((MAX_RESULT_LENGTH, NUM_CLASSES))
    best_classes = scores.argmax(axis=-1)
    decoded = ''.join(
        [CHAR_SET[class_id] for class_id in best_classes if class_id < len(CHAR_SET)]
    )
    return decoded.strip()
# Demo: run the trained model on a sample equation and print its prediction.
equation = "1 + 1"
result = solve_equation(model=model, equation=equation)
print(f"The result of '{equation}' is '{result}'")