import logging
import os

import gradio as gr
import torch
from bs4 import BeautifulSoup  # Optional: in case HTML prompt cleanup is needed
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

logger = logging.getLogger(__name__)

# Title and description
title = "NQZFaizal77 AI Small Space Warfare CLM Test"
description = "Causal Language Model Testing Interface For Small Space Warfare Finetuned Model"

# Hugging Face Hub repository IDs, keyed by the label shown in the model dropdown.
MODEL_PATHS = {
    "SwiftStrike Aero Space Warfare Intro 3POV": "nqzfaizal77ai/sa-145m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "SwiftStrike Aero Space Warfare 3POV": "nqzfaizal77ai/sa-145m-en-1bc-space-warfare-short-story-3pov-exp",
    "Halcyra Driftwing Space Warfare Intro 3POV": "nqzfaizal77ai/hd-178m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Halcyra Driftwing Space Warfare 3POV": "nqzfaizal77ai/hd-178m-en-1bc-space-warfare-short-story-3pov-exp",
    "Noble Mind Space Warfare Intro 3POV": "nqzfaizal77ai/nm-212m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Noble Mind Space Warfare 3POV": "nqzfaizal77ai/nm-212m-en-1bc-space-warfare-short-story-3pov-exp",
    "Mirabel Tempest Space Warfare Intro 3POV": "nqzfaizal77ai/mt-230m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Mirabel Tempest Space Warfare 3POV": "nqzfaizal77ai/mt-230m-en-1bc-space-warfare-short-story-3pov-exp",
}
DEFAULT_MODEL = "SwiftStrike Aero Space Warfare Intro 3POV"

# Cache of loaded models/tokenizers, keyed by MODEL_PATHS label.
# NOTE(review): nothing is ever evicted, so every model selected once stays resident.
loaded_models = {}
loaded_tokenizers = {}

# Global stop flag: set by the Stop button, polled at each decoding step.
# NOTE(review): module-level, so presumably shared by all concurrent sessions.
stop_generation = False


def load_model(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, loading it on first use.

    Args:
        model_name: A key of ``MODEL_PATHS``.

    Returns:
        Tuple of the causal-LM model and its tokenizer (cached after first load).

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_PATHS``.
    """
    if model_name not in loaded_models:
        model_path = MODEL_PATHS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        # Cache only after both loads succeed so the two dicts stay in sync.
        loaded_models[model_name] = model
        loaded_tokenizers[model_name] = tokenizer
    return loaded_models[model_name], loaded_tokenizers[model_name]


def truncate_to_last_words(text, n):
    """Keep only the last ``n`` whitespace-separated words of ``text``.

    Paragraphs are the pieces of ``text`` separated by ``"\\n\\n"``. Surviving
    words keep their paragraph grouping; inside a paragraph they are re-joined
    with single spaces. Text with ``n`` or fewer words is returned unchanged.

    Args:
        text: Prompt text to truncate.
        n: Number of trailing words to keep; values ``<= 0`` yield ``""``.

    Returns:
        The truncated text.
    """
    n = int(n)
    if n <= 0:
        # words[-0:] would select the *whole* list, so handle this explicitly.
        return ""

    # (paragraph index, word) for every word, in document order.
    tagged = [
        (idx, word)
        for idx, para in enumerate(text.split("\n\n"))
        for word in para.split()
    ]
    if len(tagged) <= n:
        return text

    kept = {}
    for idx, word in tagged[-n:]:
        kept.setdefault(idx, []).append(word)
    # Paragraph indices were inserted in ascending order, so dict order is correct.
    return "\n\n".join(" ".join(words) for words in kept.values())


class _StopFlagCriteria(StoppingCriteria):
    """Stopping criterion that ends ``model.generate`` once ``stop_generation`` is set.

    The module-level flag is read at every decoding step, so the Stop button
    takes effect while generation is running, not only before it starts.
    """

    def __call__(self, input_ids, scores, **kwargs):
        # One boolean per sequence in the batch, on the same device as the ids.
        return torch.full(
            (input_ids.shape[0],),
            bool(stop_generation),
            dtype=torch.bool,
            device=input_ids.device,
        )


def generate_text(model_name, input_text, decoding_method, max_length,
                  use_last_words, num_last_words, repetition_penalty):
    """Generate a continuation of ``input_text`` with the selected model.

    Args:
        model_name: Key of ``MODEL_PATHS`` selecting the model.
        input_text: Prompt text; HTML-like markup is stripped if present.
        decoding_method: ``"greedy"`` or ``"stochastic"`` (top-k/top-p sampling).
        max_length: Maximum number of new tokens to generate.
        use_last_words: If true, keep only the last ``num_last_words`` words
            of the prompt.
        num_last_words: Word count used when ``use_last_words`` is true;
            ``None`` disables truncation.
        repetition_penalty: Passed to ``generate`` (1.0 means no penalty).

    Returns:
        The decoded output of ``generate`` for the first sequence. It presumably
        includes the prompt tokens followed by the new tokens, since these are
        causal LMs. On failure, returns ``"Error: <message>"``. If the stop flag
        is already set before generation starts, returns a stop message.
    """
    global stop_generation
    stop_generation = False

    try:
        model, tokenizer = load_model(model_name)

        # Optional: clean HTML-like input
        if "<" in input_text and ">" in input_text:
            soup = BeautifulSoup(input_text, "html.parser")
            input_text = soup.get_text(separator="\n\n")

        # A cleared gr.Number yields None; skip truncation instead of crashing.
        if use_last_words and num_last_words is not None:
            input_text = truncate_to_last_words(input_text, int(num_last_words))

        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        generation_args = {
            # Gradio sliders may deliver floats; generate expects these types.
            "max_new_tokens": int(max_length),
            "pad_token_id": tokenizer.eos_token_id,
            "do_sample": decoding_method == "stochastic",
            "repetition_penalty": float(repetition_penalty),
            # Lets the Stop button interrupt generation mid-way.
            "stopping_criteria": StoppingCriteriaList([_StopFlagCriteria()]),
        }
        if decoding_method == "stochastic":
            generation_args.update({
                "top_k": 50,
                "top_p": 0.95,
                "temperature": 0.7,
            })

        if stop_generation:
            return "Generation stopped by user"

        with torch.no_grad():
            output = model.generate(**inputs, **generation_args)
        return tokenizer.decode(output[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the error to the user, but keep a server-side trace.
        logger.exception("Text generation failed for model %r", model_name)
        return f"Error: {str(e)}"


def stop_generation_fn():
    """Request that any in-flight generation stop at its next decoding step.

    Returns:
        Status message shown in the output textbox.
    """
    global stop_generation
    stop_generation = True
    return "Generation cancelled."
# Build Gradio UI
with gr.Blocks(title=title) as demo:
    # Page header.
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        # Left column: model selection and generation settings.
        with gr.Column():
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=list(MODEL_PATHS),
                value=DEFAULT_MODEL,
            )
            decoding_method = gr.Radio(
                label="Decoding Method",
                choices=["greedy", "stochastic"],
                value="greedy",
            )
            max_length = gr.Slider(
                label="Max Tokens",
                minimum=10,
                maximum=500,
                step=10,
                value=100,
            )
            repetition_penalty = gr.Slider(
                label="Repetition Penalty (1.0=no penalty, higher=less repetition)",
                minimum=1.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
            )
            use_last_words = gr.Checkbox(value=False, label="Use Last N Words")
            # Hidden until the checkbox above is ticked.
            num_last_words = gr.Number(
                label="N Words",
                minimum=1,
                maximum=100,
                value=20,
                visible=False,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                stop_btn = gr.Button("Stop")

        # Right column: prompt in, generated text out.
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                label="Prompt",
                placeholder="Enter your prompt...",
            )
            output_text = gr.Textbox(
                lines=10,
                label="Generated Output",
                interactive=False,
            )

    def _toggle_n_words(checked):
        """Show the "N Words" field only while "Use Last N Words" is ticked."""
        return gr.update(visible=checked)

    use_last_words.change(
        _toggle_n_words,
        inputs=use_last_words,
        outputs=num_last_words,
    )

    # Order must match generate_text's positional parameters.
    generation_inputs = [
        model_dropdown,
        input_text,
        decoding_method,
        max_length,
        use_last_words,
        num_last_words,
        repetition_penalty,
    ]
    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )
    # Bypass the queue so Stop runs immediately, even while a generation is busy.
    stop_btn.click(fn=stop_generation_fn, outputs=output_text, queue=False)

# Launch only when executed as a script.
if __name__ == "__main__":
    demo.launch()