"""Train a toy dense network that maps arithmetic equations to their results.

Equations and results are one-hot encoded character by character over
``CHAR_SET``; the model predicts one character per result position.
Run as a script to train on the ``.txt`` files in ``.math_train`` and print
the prediction for a sample equation.
"""
import os

import numpy as np

# Constants
CHAR_SET = '0123456789+-=* /'
NUM_CLASSES = len(CHAR_SET)
# Inputs are one-hot (length x NUM_CLASSES) and flattened into a Dense layer,
# so parameter count and per-sample memory grow linearly with these lengths.
# The former values (2,000,000 / 100,000) needed ~256 MB per encoded equation
# and ~4 billion weights in the first Dense layer.  Longer strings are
# truncated.
MAX_EQUATION_LENGTH = 64
MAX_RESULT_LENGTH = 32
# Results are right-padded with this character so every output position has a
# real target class; the padding is stripped again when decoding.
PAD_CHAR = ' '

# O(1) character -> class-index lookup (CHAR_SET has no duplicate characters).
_CHAR_TO_INDEX = {char: index for index, char in enumerate(CHAR_SET)}


def one_hot_encode(s, max_length, *, pad_char=None, dtype=np.float32):
    """Return a ``(max_length, NUM_CLASSES)`` one-hot encoding of *s*.

    Only the first *max_length* characters are encoded.  A character that is
    not in ``CHAR_SET`` still occupies its position but leaves that row
    all-zero.  Positions past the end of *s* are all-zero, or one-hot for
    *pad_char* when it is given.

    Raises:
        ValueError: if *pad_char* is given but is not in ``CHAR_SET``.
    """
    if pad_char is not None and pad_char not in _CHAR_TO_INDEX:
        raise ValueError(f"pad_char {pad_char!r} is not in CHAR_SET")
    encoding = np.zeros((max_length, NUM_CLASSES), dtype=dtype)
    text = s[:max_length]
    for i, char in enumerate(text):
        index = _CHAR_TO_INDEX.get(char)
        if index is not None:
            encoding[i, index] = 1
    if pad_char is not None:
        encoding[len(text):, _CHAR_TO_INDEX[pad_char]] = 1
    return encoding


def read_dataset(directory):
    """Load ``equation = result`` pairs from every ``.txt`` file in *directory*.

    Each line containing ``=`` is split at its FIRST ``=`` (so results may
    themselves contain ``=``); both sides are stripped and one-hot encoded.
    Lines without ``=`` are ignored.  Files are read in sorted name order so
    the sample order is reproducible.

    Returns:
        ``(data, labels)`` float32 arrays of shape
        ``(n, MAX_EQUATION_LENGTH, NUM_CLASSES)`` and
        ``(n, MAX_RESULT_LENGTH, NUM_CLASSES)``.  Labels are padded with
        ``PAD_CHAR``.  ``n`` may be 0, in which case the shapes are still
        three-dimensional.
    """
    data = []
    labels = []
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.txt'):
            continue
        with open(os.path.join(directory, filename), 'r', encoding='utf-8') as file:
            for line in file:
                equation, separator, result = line.strip().partition('=')
                if not separator:
                    continue
                data.append(one_hot_encode(equation.strip(), MAX_EQUATION_LENGTH))
                labels.append(
                    one_hot_encode(result.strip(), MAX_RESULT_LENGTH, pad_char=PAD_CHAR)
                )
    # reshape(-1, ...) keeps a proper 3-D shape even when no samples were found.
    data_array = np.array(data, dtype=np.float32).reshape(
        -1, MAX_EQUATION_LENGTH, NUM_CLASSES
    )
    label_array = np.array(labels, dtype=np.float32).reshape(
        -1, MAX_RESULT_LENGTH, NUM_CLASSES
    )
    return data_array, label_array


def _build_model():
    """Build and compile the equation -> per-position result classifier."""
    # Imported lazily so the encoding/dataset helpers are usable without TF.
    from tensorflow.keras.layers import Dense, Flatten, Input, Reshape, Softmax
    from tensorflow.keras.models import Sequential

    model = Sequential([
        Input(shape=(MAX_EQUATION_LENGTH, NUM_CLASSES)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(64, activation='relu'),
        Dense(MAX_RESULT_LENGTH * NUM_CLASSES),
        Reshape((MAX_RESULT_LENGTH, NUM_CLASSES)),
        # Softmax over the class axis of each result position, not across the
        # whole flattened output: each position is its own distribution,
        # matching the one-hot-per-position labels.
        Softmax(axis=-1),
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model


def solve_equation(model, equation):
    """Predict the result string for *equation* with a trained *model*.

    ``model.predict`` output is reshaped to ``(MAX_RESULT_LENGTH,
    NUM_CLASSES)``, the most likely character is chosen at each position, and
    leading/trailing spaces (the label padding) are stripped.
    """
    encoded_equation = one_hot_encode(equation, MAX_EQUATION_LENGTH)
    input_tensor = np.expand_dims(encoded_equation, axis=0)
    prediction = model.predict(input_tensor)
    predicted_indices = np.argmax(
        prediction.reshape((MAX_RESULT_LENGTH, NUM_CLASSES)), axis=-1
    )
    return ''.join(CHAR_SET[i] for i in predicted_indices).strip()


def main():
    """Train on ``.math_train`` and print the prediction for a sample equation."""
    data, labels = read_dataset('.math_train')
    if len(data) == 0:
        raise SystemExit("No 'equation = result' lines found in .math_train")

    model = _build_model()
    model.fit(data, labels, epochs=50, batch_size=32)

    equation = "1 + 1"
    result = solve_equation(model, equation)
    print(f"The result of '{equation}' is '{result}'")


if __name__ == "__main__":
    main()