import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from bs4 import BeautifulSoup

title = "NQZFaizal77 AI Small Space Warfare CLM Test"
description = "Causal Language Model Testing Interface For Small Space Warfare Finetuned Model"
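
# Display name -> Hugging Face Hub repository id for each finetuned checkpoint.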
MODEL_PATHS = {
    "SwiftStrike Aero Space Warfare Intro 3POV": "nqzfaizal77ai/sa-145m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "SwiftStrike Aero Space Warfare 3POV": "nqzfaizal77ai/sa-145m-en-1bc-space-warfare-short-story-3pov-exp",
    "Halcyra Driftwing Space Warfare Intro 3POV": "nqzfaizal77ai/hd-178m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Halcyra Driftwing Space Warfare 3POV": "nqzfaizal77ai/hd-178m-en-1bc-space-warfare-short-story-3pov-exp",
    "Noble Mind Space Warfare Intro 3POV": "nqzfaizal77ai/nm-212m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Noble Mind Space Warfare 3POV": "nqzfaizal77ai/nm-212m-en-1bc-space-warfare-short-story-3pov-exp",
    "Mirabel Tempest Space Warfare Intro 3POV": "nqzfaizal77ai/mt-230m-en-1bc-space-warfare-short-story-3pov-exp-intro",
    "Mirabel Tempest Space Warfare 3POV": "nqzfaizal77ai/mt-230m-en-1bc-space-warfare-short-story-3pov-exp"
}

DEFAULT_MODEL = "SwiftStrike Aero Space Warfare Intro 3POV"
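
# In-process caches so each model/tokenizer pair is loaded from the Hub only once.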
loaded_models = {}
loaded_tokenizers = {}
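
# Cooperative cancellation flag. generate_text checks it only once, right before
# model.generate() is called, so an in-flight generation cannot be interrupted;
# true mid-generation stopping would need something like a transformers
# StoppingCriteria callback.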
stop_generation = False
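
# Load a model/tokenizer pair from the Hub on first use, then serve it from cache.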
def load_model(model_name):
    if model_name not in loaded_models:
        model_path = MODEL_PATHS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float32,
            device_map="auto",
            low_cpu_mem_usage=True
        )
        loaded_models[model_name] = model
        loaded_tokenizers[model_name] = tokenizer
    return loaded_models[model_name], loaded_tokenizers[model_name]
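
# Trim a prompt to its last n words while preserving paragraph boundaries.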
def truncate_to_last_words(text, n):
    paragraphs = text.split("\n\n")
    words = []
    para_indices = []

    # Flatten to a word list, remembering which paragraph each word came from.
    for i, para in enumerate(paragraphs):
        for word in para.strip().split():
            words.append(word)
            para_indices.append(i)

    if len(words) <= n:
        return text

    selected_words = words[-n:]
    selected_indices = para_indices[-n:]

    # Regroup the surviving words by their original paragraph index.
    reconstructed_paragraphs = {}
    for word, idx in zip(selected_words, selected_indices):
        if idx not in reconstructed_paragraphs:
            reconstructed_paragraphs[idx] = []
        reconstructed_paragraphs[idx].append(word)

    result = '\n\n'.join(
        ' '.join(reconstructed_paragraphs[i]) for i in sorted(reconstructed_paragraphs)
    )
    return result
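
# Single-shot generation: optionally strip HTML and truncate the prompt, then
# decode greedily or sample with fixed top-k / top-p / temperature settings.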
def generate_text(model_name, input_text, decoding_method, max_length, use_last_words, num_last_words, repetition_penalty):
    global stop_generation
    stop_generation = False

    try:
        model, tokenizer = load_model(model_name)

        # Strip HTML if the prompt looks like markup (e.g. pasted rich text),
        # turning block elements into paragraph breaks.
        if '<' in input_text and '>' in input_text:
            soup = BeautifulSoup(input_text, "html.parser")
            input_text = soup.get_text(separator="\n\n")

        if use_last_words:
            input_text = truncate_to_last_words(input_text, int(num_last_words))

        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

        generation_args = {
            "max_new_tokens": int(max_length),  # slider values can arrive as floats
            "pad_token_id": tokenizer.eos_token_id,
            "do_sample": decoding_method == "stochastic",
            "repetition_penalty": repetition_penalty
        }

        # Sampling hyperparameters only take effect when do_sample=True.
        if decoding_method == "stochastic":
            generation_args.update({
                "top_k": 50,
                "top_p": 0.95,
                "temperature": 0.7
            })

        with torch.no_grad():
            # Checked only once, before generation begins; see the note on the flag above.
            if stop_generation:
                return "Generation stopped by user"
            output = model.generate(**inputs, **generation_args)

        return tokenizer.decode(output[0], skip_special_tokens=True)

    except Exception as e:
        return f"Error: {str(e)}"

def stop_generation_fn():
    global stop_generation
    stop_generation = True
    return "Generation cancelled."
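
# Assemble the interface: controls on the left, prompt and output on the right.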
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODEL_PATHS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )

            decoding_method = gr.Radio(
                choices=["greedy", "stochastic"],
                value="greedy",
                label="Decoding Method"
            )

            max_length = gr.Slider(
                minimum=10, maximum=500, value=100, step=10,
                label="Max New Tokens"
            )

            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=2.0, value=1.2, step=0.1,
                label="Repetition Penalty (1.0 = no penalty, higher = less repetition)"
            )

            use_last_words = gr.Checkbox(
                label="Use Last N Words",
                value=False
            )

            num_last_words = gr.Number(
                label="N Words",
                value=20,
                minimum=1,
                maximum=100,
                visible=False
            )

            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                stop_btn = gr.Button("Stop")

        with gr.Column():
            input_text = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=5
            )

            output_text = gr.Textbox(
                label="Generated Output",
                lines=10,
                interactive=False
            )
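
    # Show the N-words input only while the checkbox is ticked.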
    use_last_words.change(
        lambda checked: gr.update(visible=checked),
        inputs=use_last_words,
        outputs=num_last_words
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[model_dropdown, input_text, decoding_method, max_length, use_last_words, num_last_words, repetition_penalty],
        outputs=output_text
    )
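
    # queue=False lets the stop request run immediately instead of waiting in the
    # event queue behind the in-flight generation call.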
    stop_btn.click(fn=stop_generation_fn, outputs=output_text, queue=False)


if __name__ == "__main__":
    demo.launch()