Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import json | |
| import tensorflow as tf | |
| import numpy as np | |
| # Load models | |
# Model identifiers, used both as dropdown values and for dispatch in predict().
BAYES = "bayes-enron1-spam"
NN = "nn-enron1-spam"
LLM = "gisty-enron1-spam"

# (display label, model id) pairs offered in the "Model" dropdown.
MODELS = [
    ("Bayes Enron1 spam", BAYES),
    ("NN Enron1 spam", NN),
    ("GISTy Enron1 spam", LLM),
]
# Bayes model: a JSON object mapping tokens (plus the UNK fallback key) to the
# probabilities that predict_bayes() looks up.
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
with open(model_probs_path, encoding="utf-8") as f:
    model_probs = json.load(f)

# Keras classifier that predict_nn() calls on a batch of raw strings.
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)

# Keras classifier that predict_llm() calls on GIST sentence embeddings.
llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)

# Sentence Transformers should be imported after Keras models, in order to prevent it from setting Keras to legacy.
from sentence_transformers import SentenceTransformer
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")
| # Utils for Bayes | |
# Key in model_probs whose probability is used for words missing from the
# Bayes vocabulary.
UNK = "[UNK]"
def tokenize(text):
    """Split ``text`` into word tokens using Keras' ``text_to_word_sequence``.

    The tokens are used as keys into the Bayes word-probability table, so this
    presumably has to match the tokenizer used at training time (TODO confirm
    before changing it).
    """
    to_words = tf.keras.preprocessing.text.text_to_word_sequence
    return to_words(text)
def combine(probs):
    """Combine per-word spam probabilities into a single probability.

    Computes prod(p) / (prod(p) + prod(1 - p)) over ``probs``.

    Special cases:
    - If any probability is exactly 0, the result is 0.
    - If both products underflow to 0.0, the result is 0.5.
    - An empty ``probs`` also gives 0.5, since both products are then 1.
    """
    if 0 in probs:
        return 0
    spam_product = np.prod(probs)
    ham_product = np.prod([1 - p for p in probs])
    denominator = spam_product + ham_product
    # Even without exact zeros, both products can underflow to 0.0 for long or
    # extreme inputs; treat them as equally small instead of dividing by zero.
    return 0.5 if denominator == 0 else spam_product / denominator
def get_interesting_probs(probs, intr_threshold):
    """Return the ``intr_threshold`` probabilities farthest from 0.5.

    Args:
        probs: Per-word probabilities.
        intr_threshold: How many of the most "interesting" probabilities to
            keep. It is coerced with ``int()`` because Gradio's Slider is
            declared to produce a float (e.g. ``15.0``), and a float slice
            index raises ``TypeError``.

    Returns:
        A new list, sorted by decreasing ``abs(p - 0.5)``. Ties keep their
        input order, because ``sorted`` is stable even with ``reverse=True``.
    """
    count = int(intr_threshold)
    return sorted(probs, key=lambda p: abs(p - 0.5), reverse=True)[:count]
# Default number of most "interesting" words (farthest from 0.5) that the
# Bayes model combines; 15 is Graham's choice.
DEFAULT_INTR_THRESHOLD: int = 15
def unbias(p):
    """Map a Graham-style word probability to its unbiased counterpart.

    If ``p = b / (b + 2g)`` (Graham's doubling of good-word counts, with b and
    g the normalized bad/good frequencies), the result is ``b / (b + g)``,
    which simplifies to ``2p / (1 + p)``.
    """
    return 2 * p / (1 + p)
| # Predict functions | |
def predict_bayes(text, intr_threshold, unbiased=False):
    """Spam probability of ``text`` under the Bayes word-probability model.

    Each token is looked up in ``model_probs``. Known words can optionally be
    passed through ``unbias``. Unknown words fall back to the ``UNK`` entry,
    which is never unbiased. The ``intr_threshold`` most interesting
    probabilities are then combined into one.
    """
    def word_prob(word):
        # Unknown word: use the shared UNK probability as-is.
        if word not in model_probs:
            return model_probs[UNK]
        known = model_probs[word]
        return unbias(known) if unbiased else known

    word_probs = [word_prob(word) for word in tokenize(text)]
    return combine(get_interesting_probs(word_probs, intr_threshold))
def predict_nn(text):
    """Spam probability from the NN model, which is fed the raw text."""
    batch = np.array([text])  # a batch containing the single input string
    scores = nn_model(batch)
    return float(scores[0][0].numpy())
def predict_llm(text):
    """Spam probability from the GISTy model, which is fed a GIST sentence embedding."""
    batch = np.array([st_model.encode(text)])  # a batch of one embedding
    scores = llm_model(batch)
    return float(scores[0][0].numpy())
def predict(model, input_txt, unbiased, intr_threshold):
    """Dispatch ``input_txt`` to the predictor selected by ``model``.

    ``unbiased`` and ``intr_threshold`` are used only by the Bayes model.
    Returns None for an unrecognized model id.
    """
    if model == BAYES:
        return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
    if model == NN:
        return predict_nn(input_txt)
    if model == LLM:
        return predict_llm(input_txt)
    return None
# UI
# Gradio interface wired to predict(). The order of `inputs` followed by
# `additional_inputs` must match predict()'s parameters:
# (model, input_txt, unbiased, intr_threshold).
demo = gr.Interface(
    theme=gr.themes.Origin(  # Gradio 4-like
        primary_hue="yellow",
    ),
    fn=predict,
    # Main inputs -> predict(model, input_txt, ...).
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model",
                    info="Learn more about the models [here](https://huggingface.co/collections/tbitai/bayes-or-spam-6700033fa145e298ec849249)"),
        gr.TextArea(label="Email"),
    ],
    # Bayes-only settings, collapsed by default -> predict(..., unbiased, intr_threshold).
    additional_inputs_accordion=gr.Accordion("Additional configuration for Bayes", open=False),
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="Correct Graham's bias?"),
        gr.Slider(minimum=1, maximum=DEFAULT_INTR_THRESHOLD + 5, step=1, value=DEFAULT_INTR_THRESHOLD,
                  label="Interestingness threshold",
                  info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)"),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict if your email is a spam! 📨",
    # Each example row gives values for all four inputs in order. The NN row
    # leaves the Bayes-only settings as None, since predict() ignores them for NN.
    examples=[
        [NN, "Enron actuals for June 26, 2000", None, None],
        [BAYES, "Stop the aging clock\nNerissa", True, DEFAULT_INTR_THRESHOLD],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)

if __name__ == "__main__":
    demo.launch()