Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| import torch | |
| import difflib | |
| from threading import Thread | |
| from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer | |
# Hugging Face Hub checkpoint served by this demo.
model_id = "textcleanlm/textcleanlm-1-8b"

# Filled in lazily by load_model(); both remain None until it first runs.
model, tokenizer = None, None
| def load_model(): | |
| global model, tokenizer | |
| if model is None: | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| # Add padding token if needed | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # Try different model classes | |
| for model_class in [AutoModelForSeq2SeqLM, AutoModelForCausalLM, AutoModel]: | |
| try: | |
| model = model_class.from_pretrained( | |
| model_id, | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto" | |
| ) | |
| break | |
| except: | |
| continue | |
| if model is None: | |
| raise ValueError(f"Could not load model {model_id}") | |
| return model, tokenizer | |
def create_diff_html(original, cleaned):
    """Render a unified diff of ``original`` vs ``cleaned`` as an HTML fragment.

    File headers are grey, hunk markers blue, added lines green and removed
    lines red, all inside a single monospace ``pre-wrap`` container.

    Args:
        original: Text before cleaning.
        cleaned: Text after cleaning.

    Returns:
        str: HTML markup. Every diff line is HTML-escaped, so the fragment
        renders safely even when either text contains markup. When the texts
        have no line-level differences, the container holds a
        "No changes." note instead of being empty.
    """
    import html

    header_style = "color: #666;"
    hunk_style = "color: #0066cc; font-weight: bold;"
    added_style = "background-color: #e6ffed; color: #24292e;"
    removed_style = "background-color: #ffeef0; color: #24292e;"

    # Split WITHOUT line endings and use lineterm='' so every diff line is
    # uniformly newline-free (difflib's documented pairing). This also stops a
    # lone trailing newline (the model output is stripped) from being
    # reported as a changed last line.
    diff_lines = difflib.unified_diff(
        original.splitlines(),
        cleaned.splitlines(),
        fromfile="Original",
        tofile="Cleaned",
        lineterm="",
    )

    parts = ['<div style="font-family: monospace; font-size: 12px; white-space: pre-wrap;">']
    for index, line in enumerate(diff_lines):
        if index < 2:
            # unified_diff emits the '---' / '+++' file headers exactly once,
            # as its first two lines. Classifying by position (not prefix)
            # keeps a removed "-- x" line ("--- x") or an added "++ x" line
            # ("+++ x") from being styled as a header.
            style = header_style
        elif line.startswith("@@"):
            style = hunk_style
        elif line.startswith("+"):
            style = added_style
        elif line.startswith("-"):
            style = removed_style
        else:
            style = None

        escaped = html.escape(line)
        if style is None:
            parts.append(f"<div>{escaped}</div>")
        else:
            parts.append(f'<div style="{style}">{escaped}</div>')

    if len(parts) == 1:
        # No diff lines at all: say so rather than rendering an empty box.
        parts.append(f'<div style="{header_style}">No changes.</div>')

    parts.append("</div>")
    return "".join(parts)
def clean_text(text):
    """Stream a cleaned version of ``text`` from the model, then a diff.

    Generator used as a Gradio event handler. Each yield is a pair
    ``(cleaned_text, diff_html)``: while tokens are streaming the diff slot is
    ``""``; the final yield carries the HTML diff of ``text`` against the
    finished output.

    Args:
        text: Raw text to clean; sent to the model as a single user message.

    Yields:
        tuple[str, str]: the cleaned text generated so far (whitespace-stripped)
        and the diff HTML (empty until generation has finished).

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here once the stream has been closed.
    """
    # NOTE(review): `spaces` is imported by this file but this handler is not
    # decorated with `@spaces.GPU`; on ZeroGPU hardware that decorator is
    # needed for CUDA access — confirm the Space's hardware.
    model, tokenizer = load_model()

    messages = [{"role": "user", "content": text}]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The chat template already contains any BOS/special tokens, so the
    # tokenizer must not add a second set.
    # NOTE(review): with the tokenizer's default (typically right-side)
    # truncation, an over-long prompt also loses the generation prompt at the
    # end of the template; consider truncating `text` itself instead.
    inputs = tokenizer(
        formatted_text,
        return_tensors="pt",
        max_length=4096,
        truncation=True,
        add_special_tokens=False,
    )
    # Follow wherever device_map="auto" placed the model instead of assuming CUDA.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # skip_prompt=True makes the streamer emit only newly generated text.
    # Slicing off len(formatted_text) characters was wrong because
    # skip_special_tokens shortens the decoded prompt.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        # Budget for the output alone: max_length counted the (up to 4096-token)
        # prompt too, leaving long inputs little or no room to generate.
        max_new_tokens=4096,
        num_beams=1,  # Set to 1 for streaming
        do_sample=True,
        temperature=1.0,
        streamer=streamer,
    )

    errors = []

    def _generate():
        """Run generation; on failure record the error and unblock the consumer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Without this the `for new_text in streamer` loop below would
            # wait forever for an end-of-stream signal that never comes.
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    generated_text = ""
    full_output = ""
    for new_text in streamer:
        full_output += new_text
        generated_text = full_output.strip()
        yield generated_text, ""
    thread.join()
    if errors:
        raise errors[0]

    # After generation is complete, create diff
    diff_html = create_diff_html(text, generated_text)
    yield generated_text, diff_html
# Build the interface with Blocks for finer control over the layout.
with gr.Blocks(title="TextCleanLM Demo") as demo:
    gr.Markdown("# TextCleanLM Demo")
    # Name the checkpoint that is actually loaded (model_id) instead of a
    # hard-coded model name that did not match it.
    gr.Markdown(f"Simple demo for text cleaning using {model_id} model")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to clean...",
                label="Input Text",
            )
            submit_btn = gr.Button("Clean Text", variant="primary")

    with gr.Row():
        output_text = gr.Textbox(
            lines=5,
            label="Cleaned Text",
            interactive=False,
        )

    with gr.Row():
        diff_display = gr.HTML(label="Diff View")

    # clean_text is a generator: each yielded (text, diff) pair updates both
    # outputs, which is what produces the streaming effect in the UI.
    submit_btn.click(
        fn=clean_text,
        inputs=input_text,
        outputs=[output_text, diff_display],
    )

if __name__ == "__main__":
    demo.launch()