import os
import gc
import subprocess

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE lets it install
# without compiling the CUDA kernels on the Space build machine (which has no GPU).
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)

import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread, Event
import time
import uuid
import re
from diffusers import ChromaPipeline
# Pre-load ONLY Chroma (not LLMs, to support custom models)
print("Loading Chroma1-HD...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device at module level: {device}")

chroma_pipe = ChromaPipeline.from_pretrained(
    "lodestones/Chroma1-HD",
    torch_dtype=torch.bfloat16
)
chroma_pipe = chroma_pipe.to(device)
print("✓ Chroma1-HD ready")
MODEL_CONFIGS = {
    "Nekochu/Luminia-13B-v3": {
        "system": "",
        "examples": [
            "### Instruction:\nCreate stable diffusion metadata based on the given english description. Luminia\n\n### Input:\nfavorites and popular SFW",
            "### Instruction:\nProvide tips on stable diffusion to optimize low token prompts and enhance quality include prompt example."
        ],
        "supports_image_gen": True,
        "sd_temp": 0.3,
        "sd_top_p": 0.8,
        "branch": None  # Uses main/default branch
    },
    "Nekochu/Luminia-8B-v4-Chan": {
        "system": "write a response like a 4chan user",
        "examples": [],
        "supports_image_gen": False,
        "branch": "Llama-3-8B-4Chan_SD_QLoRa"
    },
    "Nekochu/Luminia-8B-RP": {
        "system": "You are a knowledgeable and empathetic mental health professional.",
        "examples": ["How to cope with anxiety?"],
        "supports_image_gen": False,
        "branch": None
    }
}
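# Any other Hugging Face model ID can also be typed straight into the Model dropdown
# (optionally as "repo_id:branch"); MODEL_CONFIGS only supplies per-model defaults.
# A hypothetical extra entry would look like:
#   "someuser/some-alpaca-model": {
#       "system": "",
#       "examples": [],
#       "supports_image_gen": False,
#       "branch": None
#   },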
DEFAULT_MODELS = list(MODEL_CONFIGS.keys())
models_cache = {}
stop_event = Event()
current_thread = None
MAX_CACHE_SIZE = 2
DEFAULT_MODEL = DEFAULT_MODELS[0]
def parse_model_id(model_id_str):
    """Parse a model ID and optional branch (format: 'model_id:branch')."""
    if ':' in model_id_str:
        parts = model_id_str.split(':', 1)
        return parts[0], parts[1]
    if model_id_str in MODEL_CONFIGS:  # Known model that may pin a specific branch
        config = MODEL_CONFIGS[model_id_str]
        return model_id_str, config.get('branch', None)
    return model_id_str, None
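# Illustrative behaviour of parse_model_id (the custom IDs below are hypothetical):
#   parse_model_id("Nekochu/Luminia-8B-v4-Chan")         -> ("Nekochu/Luminia-8B-v4-Chan", "Llama-3-8B-4Chan_SD_QLoRa")
#   parse_model_id("someuser/custom-model:experimental") -> ("someuser/custom-model", "experimental")
#   parse_model_id("someuser/custom-model")              -> ("someuser/custom-model", None)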
def parse_sd_metadata(text: str):
    """Parse A1111-style Stable Diffusion metadata out of the model's reply."""
    metadata = {
        'prompt': '',
        'negative_prompt': '',
        'steps': 25,
        'cfg_scale': 7.0,
        'seed': 42,
        'width': 1024,
        'height': 1024
    }
    if not text:
        metadata['prompt'] = '(masterpiece, best quality), 1girl'
        return metadata
    try:
        if "Negative prompt:" in text:
            parts = text.split("Negative prompt:", 1)
            metadata['prompt'] = parts[0].strip().rstrip('.,;')[:500]
            if len(parts) > 1:
                neg_section = parts[1]
                param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', neg_section)
                if param_match:
                    metadata['negative_prompt'] = neg_section[:param_match.start()].strip().rstrip('.,;')[:300]
                else:
                    metadata['negative_prompt'] = neg_section.strip().rstrip('.,;')[:300]
        else:
            param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', text)
            if param_match:
                metadata['prompt'] = text[:param_match.start()].strip().rstrip('.,;')[:500]
            else:
                metadata['prompt'] = text.strip()[:500]
        patterns = {
            'Steps': (r'Steps:\s*(\d+)', lambda x: min(int(x), 30)),
            'CFG scale': (r'CFG scale:\s*([\d.]+)', float),
            'Seed': (r'Seed:\s*(\d+)', lambda x: int(x) % (2**32)),
            'Size': (r'Size:\s*(\d+)x(\d+)', None)
        }
        for key, (pattern, converter) in patterns.items():
            match = re.search(pattern, text)
            if match:
                try:
                    if key == 'Size':
                        metadata['width'] = min(max(int(match.group(1)), 512), 1536)
                        metadata['height'] = min(max(int(match.group(2)), 512), 1536)
                    else:
                        metadata[key.lower().replace(' ', '_')] = converter(match.group(1))
                except Exception:
                    pass  # Keep the default for any field that fails to convert
    except Exception:
        pass  # Fall back to defaults if the text is not parseable at all
    if not metadata['prompt']:
        metadata['prompt'] = '(masterpiece, best quality), 1girl'
    return metadata
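# Illustrative behaviour of parse_sd_metadata on a typical A1111-style reply:
#   parse_sd_metadata("1girl, solo\nNegative prompt: lowres\nSteps: 20, CFG scale: 7.5, Seed: 123, Size: 832x1216")
#   -> {'prompt': '1girl, solo', 'negative_prompt': 'lowres', 'steps': 20,
#       'cfg_scale': 7.5, 'seed': 123, 'width': 832, 'height': 1216}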
def clear_old_cache():
    global models_cache
    if len(models_cache) >= MAX_CACHE_SIZE:
        # Evict the least-recently-used model before loading another one
        oldest = min(models_cache.items(), key=lambda x: x[1].get('last_used', 0))
        del models_cache[oldest[0]]
        gc.collect()
        torch.cuda.empty_cache()
# NOTE: the @spaces.GPU decorator is assumed here -- the `spaces` import is otherwise
# unused and this Space runs on ZeroGPU, which only attaches a GPU inside decorated
# functions. The 120 s budget is an assumption to cover model loading plus streaming.
@spaces.GPU(duration=120)
def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
    """Text generation with branch support."""
    global models_cache, stop_event, current_thread
    stop_event.clear()

    model_id, branch = parse_model_id(model_id_str)  # Parse model ID and branch
    cache_key = f"{model_id}:{branch}" if branch else model_id
    config = MODEL_CONFIGS.get(model_id, {})

    # For SD-metadata requests on Luminia-13B-v3, fall back to the model's tighter sampling defaults
    if "Luminia-13B-v3" in model_id and ("stable diffusion" in message.lower() or "metadata" in message.lower()):
        temp = config.get('sd_temp', 0.3)
        top_p = config.get('sd_top_p', 0.8)
        print(f"Using SD settings: temp={temp}, top_p={top_p}")

    if cache_key not in models_cache:
        clear_old_cache()
        try:
            yield history + [[message, f"📥 Loading {model_id}{f' ({branch})' if branch else ''}..."]], "Loading..."

            # Load with branch/revision support
            load_kwargs = {"trust_remote_code": True}
            if branch:
                load_kwargs["revision"] = branch
                print(f"Loading from branch: {branch}")

            tokenizer = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
            tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token

            # Quantize to 4-bit NF4 at load time to keep VRAM usage low
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True
            )
            model_kwargs = {
                "quantization_config": bnb_config,
                "device_map": "auto",
                "trust_remote_code": True,
                "attn_implementation": "flash_attention_2" if torch.cuda.is_available() else None,
                "low_cpu_mem_usage": True
            }
            if branch:
                model_kwargs["revision"] = branch

            model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
            models_cache[cache_key] = {
                "model": model,
                "tokenizer": tokenizer,
                "last_used": time.time()
            }
        except Exception as e:
            yield history + [[message, f"❌ Failed: {str(e)[:200]}"]], "Error"
            return

    models_cache[cache_key]['last_used'] = time.time()
    model = models_cache[cache_key]["model"]
    tokenizer = models_cache[cache_key]["tokenizer"]
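    # Re-serialize the chat history into the Alpaca-style layout built below,
    # e.g. (illustrative, for one earlier turn plus a new message):
    #   ### Instruction:
    #   <earlier user message>
    #
    #   ### Response:
    #   <earlier reply>
    #
    #   ### Instruction:
    #   <new message>
    #
    #   ### Response: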
    prompt = ""
    if system:
        prompt = f"{system}\n\n"
    for user_msg, assistant_msg in history:
        if "### Instruction:" in user_msg:
            prompt += f"{user_msg}\n### Response:\n{assistant_msg}\n\n"
        else:
            prompt += f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}\n\n"
    if "### Instruction:" in message and "### Response:" not in message:
        prompt += f"{message}\n### Response:\n"
    elif "### Instruction:" not in message:
        prompt += f"### Instruction:\n{message}\n\n### Response:\n"
    else:
        prompt += message
    print(f"Prompt ending: ...{prompt[-200:]}")
    try:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        input_tokens = inputs['input_ids'].shape[1]
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    except Exception as e:
        yield history + [[message, f"❌ Tokenization failed: {str(e)}"]], "Error"
        return

    print(f"📝 {input_tokens} tokens | Temp: {temp} | Top-p: {top_p}")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5)
    gen_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": min(max_tokens, 2048),
        "temperature": max(temp, 0.01),
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": rep_penalty,
        "do_sample": temp > 0.01,
        "pad_token_id": tokenizer.pad_token_id
    }
    current_thread = Thread(target=model.generate, kwargs=gen_kwargs)
    current_thread.start()

    start_time = time.time()
    partial = ""
    token_count = 0
    try:
        for text in streamer:
            if stop_event.is_set():
                break
            partial += text
            token_count = len(tokenizer.encode(partial, add_special_tokens=False))
            elapsed = time.time() - start_time
            if elapsed > 0:
                yield history + [[message, partial]], f"⚡ {token_count} @ {token_count/elapsed:.1f} t/s"
    except Exception:
        pass  # Streamer timeout or interruption: fall through and report what we have
    finally:
        if current_thread.is_alive():
            stop_event.set()
            current_thread.join(timeout=2)

    final_time = time.time() - start_time
    yield history + [[message, partial]], f"✅ {token_count} tokens in {final_time:.1f}s"
# NOTE: @spaces.GPU assumed here for the same reason as generate_text_gpu (ZeroGPU).
@spaces.GPU
def generate_image_gpu(text_output):
    """Image generation with the pre-loaded Chroma pipeline."""
    global chroma_pipe
    if not text_output or text_output.isspace():
        return None, "❌ No valid text", gr.update(visible=False)
    try:
        metadata = parse_sd_metadata(text_output)
        print(f"Generating: {metadata['width']}x{metadata['height']} | Steps: {metadata['steps']}")

        if torch.cuda.is_available():
            chroma_pipe = chroma_pipe.to("cuda")
        generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(metadata['seed'])

        image = chroma_pipe(
            prompt=metadata['prompt'],
            negative_prompt=metadata['negative_prompt'],
            generator=generator,
            num_inference_steps=metadata['steps'],
            guidance_scale=metadata['cfg_scale'],
            width=metadata['width'],
            height=metadata['height']
        ).images[0]

        status = f"✅ {metadata['width']}x{metadata['height']} | {metadata['steps']} steps | CFG: {metadata['cfg_scale']} | Seed: {metadata['seed']}"
        return image, status, gr.update(visible=False)
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Failed: {str(e)[:200]}", gr.update(visible=False)
def stop_generation():
    global stop_event, current_thread
    stop_event.set()
    if current_thread and current_thread.is_alive():
        current_thread.join(timeout=2)
    return gr.update(visible=True), gr.update(visible=False)
css = """
#chatbot {height: 305px;}
#input-row {display: flex; gap: 4px;}
#input-box {flex-grow: 1;}
#button-group {display: inline-flex; flex-direction: column; gap: 2px; width: 45px;}
#button-group button {width: 40px; height: 28px; padding: 2px; font-size: 14px;}
#status {font-size: 11px; color: #666; margin-top: 2px;}
#image-output {max-height: 400px; margin-top: 8px;}
#img-loading {font-size: 11px; color: #666; margin: 4px 0;}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(elem_id="chatbot")
            with gr.Row(elem_id="input-row"):
                msg = gr.Textbox(
                    label="Instruction",
                    lines=3,
                    elem_id="input-box",
                    value=MODEL_CONFIGS[DEFAULT_MODEL]["examples"][0] if MODEL_CONFIGS[DEFAULT_MODEL]["examples"] else "",
                    scale=10
                )
                with gr.Column(elem_id="button-group", scale=1, min_width=45):
                    submit = gr.Button("▶", variant="primary", size="sm")
                    stop = gr.Button("⏹", variant="stop", size="sm", visible=False)
                    undo = gr.Button("↩", size="sm")
                    clear = gr.Button("🗑", size="sm")
            status = gr.Markdown("", elem_id="status")
            with gr.Row():
                image_btn = gr.Button("🎨 Generate Image using Chroma1-HD", visible=False, variant="secondary")
            last_text = gr.Textbox(visible=False)
            img_loading = gr.Markdown("", visible=False, elem_id="img-loading")
            image_output = gr.Image(visible=False, elem_id="image-output")
            image_status = gr.Markdown("", visible=False)
            examples = gr.Examples(
                examples=[[ex] for ex in MODEL_CONFIGS[DEFAULT_MODEL]["examples"] if ex],
                inputs=msg,
                label="Examples"
            )
        with gr.Column(scale=1):
            model = gr.Dropdown(
                DEFAULT_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                allow_custom_value=True,
                info="Custom HF ID + optional :branch"
            )
            with gr.Accordion("Settings", open=False):
                system = gr.Textbox(
                    label="System Prompt",
                    value=MODEL_CONFIGS[DEFAULT_MODEL]["system"],
                    lines=2
                )
                temp = gr.Slider(0.1, 1.0, 0.35, label="Temperature")
                top_p = gr.Slider(0.5, 1.0, 0.85, label="Top-p")
                top_k = gr.Slider(10, 100, 40, label="Top-k")
                rep_penalty = gr.Slider(1.0, 1.5, 1.1, label="Repetition Penalty")
                max_tokens = gr.Slider(256, 2048, 1024, label="Max Tokens")
            export_btn = gr.Button("💾 Export", size="sm")
            export_file = gr.File(visible=False)
    def update_ui_on_model_change(model_id_str):
        """Update all UI components when the model changes."""
        model_id, branch = parse_model_id(model_id_str)
        config = MODEL_CONFIGS.get(model_id, {"system": "", "examples": [""], "supports_image_gen": False})
        return (
            config["system"],                                      # system
            config["examples"][0] if config["examples"] else "",   # msg
            gr.update(visible=False),                              # image_btn
            "",                                                    # last_text
            gr.update(value=None, visible=False),                  # image_output (clear and hide)
            gr.update(value="", visible=False),                    # image_status (clear and hide)
            gr.update(visible=False)                               # img_loading
        )

    def check_image_availability(model_id_str, history):
        model_id, _ = parse_model_id(model_id_str)
        if "Luminia-13B-v3" in model_id and history and len(history) > 0:
            return gr.update(visible=True), history[-1][1]
        return gr.update(visible=False), ""
    submit.click(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None, [submit, stop]
    ).then(
        generate_text_gpu,
        [model, msg, chatbot, system, temp, top_p, top_k, max_tokens, rep_penalty],
        [chatbot, status]
    ).then(
        lambda: (gr.update(visible=True), gr.update(visible=False)),
        None, [submit, stop]
    ).then(
        check_image_availability,
        [model, chatbot],
        [image_btn, last_text]
    )

    stop.click(stop_generation, None, [submit, stop])

    image_btn.click(
        lambda: gr.update(value="🎨 Generating...", visible=True),
        None, img_loading
    ).then(
        generate_image_gpu,
        last_text,
        [image_output, image_status, img_loading]
    ).then(
        lambda img: (gr.update(visible=img is not None), gr.update(visible=True)),
        image_output,
        [image_output, image_status]
    )
    model.change(
        update_ui_on_model_change,
        model,
        [system, msg, image_btn, last_text, image_output, image_status, img_loading]
    )
    undo.click(
        lambda h: h[:-1] if h else h,
        chatbot, chatbot
    ).then(
        check_image_availability,
        [model, chatbot],
        [image_btn, last_text]
    )

    clear.click(
        lambda: ([], "", "", None, "", gr.update(visible=False), "", gr.update(visible=False)),
        None, [chatbot, msg, status, image_output, image_status, image_btn, last_text, img_loading]
    )
    def export_chat(history):
        if not history:
            return None
        content = "\n\n".join([f"User: {u}\n\nAssistant: {a}" for u, a in history])
        path = f"chat_{uuid.uuid4().hex[:8]}.txt"
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        return path

    export_btn.click(export_chat, chatbot, export_file).then(
        lambda: gr.update(visible=True), None, export_file
    )

demo.queue().launch()