# nqzfaizal77ai's picture
# Update app.py
# 6a285af verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
from bs4 import BeautifulSoup # Optional: in case HTML prompt cleanup is needed
# Title and description shown at the top of the Gradio interface.
title = "NQZFaizal77 AI Small Space Warfare CLM Test"
description = "Casual Language Model Testing Interface For Small Space Warfare Finetuned Model"

# Local model paths
# Maps a human-readable display name (used in the UI dropdown) to its
# Hugging Face Hub repository id.
MODEL_PATHS = {
    "SwiftStrike Aero Space Warfare Intro 3POV": "nqzfaizal77ai/sa-145m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "SwiftStrike Aero Space Warfare 3POV": "nqzfaizal77ai/sa-145m-en-1bc-space-warfare-short-story-3pov-exp",
    "Halcyra Driftwing Space Warfare Intro 3POV" : "nqzfaizal77ai/hd-178m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Halcyra Driftwing Space Warfare 3POV" : "nqzfaizal77ai/hd-178m-en-1bc-space-warfare-short-story-3pov-exp",
    "Noble Mind Space Warfare Intro 3POV" : "nqzfaizal77ai/nm-212m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Noble Mind Space Warfare 3POV" : "nqzfaizal77ai/nm-212m-en-1bc-space-warfare-short-story-3pov-exp",
    "Mirabel Tempest Space Warfare Intro 3POV" : "nqzfaizal77ai/mt-230m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Mirabel Tempest Space Warfare 3POV" : "nqzfaizal77ai/mt-230m-en-1bc-space-warfare-short-story-3pov-exp"
}

# Display name pre-selected in the model dropdown on startup.
DEFAULT_MODEL = "SwiftStrike Aero Space Warfare Intro 3POV"

# Cache loaded models
# Keyed by display name; avoids re-downloading/re-loading weights on every
# request. Populated lazily by load_model().
loaded_models = {}
loaded_tokenizers = {}

# Global stop flag
# Set True by stop_generation_fn(); checked by generate_text() once before
# calling model.generate().
stop_generation = False
def load_model(model_name):
    """Return the (model, tokenizer) pair for *model_name*, loading lazily.

    On first use the weights and tokenizer are fetched from the Hugging Face
    Hub path registered in MODEL_PATHS and cached in the module-level dicts;
    later calls are served straight from the cache.
    """
    # Fast path: already loaded once before.
    if model_name in loaded_models:
        return loaded_models[model_name], loaded_tokenizers[model_name]

    repo_id = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    loaded_models[model_name] = model
    loaded_tokenizers[model_name] = tokenizer
    return model, tokenizer
def truncate_to_last_words(text, n):
    """Keep only the trailing *n* words of *text*, preserving paragraph breaks.

    Paragraphs are delimited by blank lines ("\\n\\n"). If the text contains
    n words or fewer it is returned unchanged; otherwise the last n words are
    re-joined into their original paragraphs (earlier paragraphs that lose
    all their words disappear entirely).
    """
    # Tag every word with the index of the paragraph it came from.
    tagged = [
        (idx, word)
        for idx, paragraph in enumerate(text.split("\n\n"))
        for word in paragraph.strip().split()
    ]

    if len(tagged) <= n:
        return text

    # Regroup the surviving tail of words by paragraph index. The tail is
    # already in paragraph order, so sorted() just yields indices in order.
    grouped = {}
    for idx, word in tagged[-n:]:
        grouped.setdefault(idx, []).append(word)

    return '\n\n'.join(' '.join(grouped[idx]) for idx in sorted(grouped))
def generate_text(model_name, input_text, decoding_method, max_length, use_last_words, num_last_words, repetition_penalty):
    """Generate a story continuation for *input_text* with the chosen model.

    Parameters mirror the Gradio widgets: model_name (dropdown key into
    MODEL_PATHS), decoding_method ("greedy" or "stochastic"), max_length
    (max new tokens), use_last_words/num_last_words (optionally truncate the
    prompt to its trailing N words), repetition_penalty (>= 1.0).

    Returns the decoded text, or an "Error: ..." string on failure so the
    UI textbox shows the problem instead of the app crashing.
    """
    global stop_generation
    stop_generation = False
    try:
        # Load model (cached after the first call)
        model, tokenizer = load_model(model_name)

        # Optional: clean HTML-like input before tokenizing.
        if '<' in input_text and '>' in input_text:
            soup = BeautifulSoup(input_text, "html.parser")
            input_text = soup.get_text(separator="\n\n")

        if use_last_words:
            input_text = truncate_to_last_words(input_text, int(num_last_words))

        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        generation_args = {
            # Gradio Slider values can arrive as floats; generate() requires
            # an int for max_new_tokens, so cast defensively.
            "max_new_tokens": int(max_length),
            "pad_token_id": tokenizer.eos_token_id,
            "do_sample": decoding_method == "stochastic",
            "repetition_penalty": float(repetition_penalty)
        }
        if decoding_method == "stochastic":
            generation_args.update({
                "top_k": 50,
                "top_p": 0.95,
                "temperature": 0.7
            })

        with torch.no_grad():
            # NOTE(review): the flag is only sampled once, before generate()
            # starts — it cannot interrupt an in-flight generation.
            if stop_generation:
                return "Generation stopped by user"
            output = model.generate(**inputs, **generation_args)
        return tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the error in the output textbox.
        return f"Error: {str(e)}"
def stop_generation_fn():
    """Raise the module-level stop flag and report cancellation to the UI.

    generate_text() reads this flag once before starting generation; the
    returned string is written into the output textbox.
    """
    global stop_generation
    stop_generation = True
    return "Generation cancelled."
# Build Gradio UI
# Layout: left column holds the controls, right column the prompt/output
# textboxes. Event wiring at the bottom connects widgets to the handlers above.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # Which finetuned checkpoint to run (keys of MODEL_PATHS).
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_PATHS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )
            # "stochastic" enables sampling (top-k/top-p/temperature);
            # "greedy" is deterministic argmax decoding.
            decoding_method = gr.Radio(
                choices=["greedy", "stochastic"],
                value="greedy",
                label="Decoding Method"
            )
            max_length = gr.Slider(
                minimum=10, maximum=500, value=100, step=10,
                label="Max Tokens"
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                label="Repetition Penalty (1.0=no penalty, higher=less repetition)"
            )
            # Optional prompt truncation to the trailing N words
            # (see truncate_to_last_words).
            use_last_words = gr.Checkbox(
                label="Use Last N Words",
                value=False
            )
            # Hidden until the checkbox above is ticked (see .change below).
            num_last_words = gr.Number(
                label="N Words",
                value=20,
                minimum=1,
                maximum=100,
                visible=False
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                stop_btn = gr.Button("Stop")
        with gr.Column():
            input_text = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=5
            )
            output_text = gr.Textbox(
                label="Generated Output",
                lines=10,
                interactive=False
            )

    # Toggle visibility of the N-words input with the checkbox state.
    use_last_words.change(
        lambda checked: gr.update(visible=checked),
        inputs=use_last_words,
        outputs=num_last_words
    )
    generate_btn.click(
        fn=generate_text,
        inputs=[model_dropdown, input_text, decoding_method, max_length, use_last_words, num_last_words, repetition_penalty],
        outputs=output_text
    )
    # queue=False so the stop request is not serialized behind a running
    # generation job.
    stop_btn.click(fn=stop_generation_fn, outputs=output_text, queue=False)

# Launch
if __name__ == "__main__":
    demo.launch()