import os
import gc
import subprocess
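# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the package
# skip compiling its CUDA kernels during pip's build step, so the install can finish
# before any GPU is attached to the process.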
subprocess.run('pip install flash-attn --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread, Event
import time
import uuid
import re
from diffusers import ChromaPipeline
# Pre-load ONLY Chroma (not LLMs, to support custom models)
print("Loading Chroma1-HD...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device at module level: {device}")
chroma_pipe = ChromaPipeline.from_pretrained(
"lodestones/Chroma1-HD",
torch_dtype=torch.bfloat16
)
chroma_pipe = chroma_pipe.to(device)
print("✓ Chroma1-HD ready")
MODEL_CONFIGS = {
"Nekochu/Luminia-13B-v3": {
"system": "",
"examples": [
"### Instruction:\nCreate stable diffusion metadata based on the given english description. Luminia\n\n### Input:\nfavorites and popular SFW",
"### Instruction:\nProvide tips on stable diffusion to optimize low token prompts and enhance quality include prompt example."
],
"supports_image_gen": True,
"sd_temp": 0.3,
"sd_top_p": 0.8,
"branch": None # Uses main/default branch
},
"Nekochu/Luminia-8B-v4-Chan": {
"system": "write a response like a 4chan user",
"examples": [],
"supports_image_gen": False,
"branch": "Llama-3-8B-4Chan_SD_QLoRa"
},
"Nekochu/Luminia-8B-RP": {
"system": "You are a knowledgeable and empathetic mental health professional.",
"examples": ["How to cope with anxiety?"],
"supports_image_gen": False,
"branch": None
}
}
DEFAULT_MODELS = list(MODEL_CONFIGS.keys())
models_cache = {}
stop_event = Event()
current_thread = None
MAX_CACHE_SIZE = 2
DEFAULT_MODEL = DEFAULT_MODELS[0]
def parse_model_id(model_id_str):
"""Parse model ID and optional branch (format: 'model_id:branch')"""
if ':' in model_id_str:
parts = model_id_str.split(':', 1)
return parts[0], parts[1]
if model_id_str in MODEL_CONFIGS: # Check if it's a known model with a specific branch
config = MODEL_CONFIGS[model_id_str]
return model_id_str, config.get('branch', None)
return model_id_str, None
def parse_sd_metadata(text: str):
"""Parse SD metadata"""
metadata = {
'prompt': '',
'negative_prompt': '',
'steps': 25,
'cfg_scale': 7.0,
'seed': 42,
'width': 1024,
'height': 1024
}
if not text:
metadata['prompt'] = '(masterpiece, best quality), 1girl'
return metadata
try:
if "Negative prompt:" in text:
parts = text.split("Negative prompt:", 1)
metadata['prompt'] = parts[0].strip().rstrip('.,;')[:500]
if len(parts) > 1:
neg_section = parts[1]
param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', neg_section)
if param_match:
metadata['negative_prompt'] = neg_section[:param_match.start()].strip().rstrip('.,;')[:300]
else:
metadata['negative_prompt'] = neg_section.strip().rstrip('.,;')[:300]
else:
param_match = re.search(r'(Steps:|Sampler:|CFG scale:|Seed:|Size:)', text)
if param_match:
metadata['prompt'] = text[:param_match.start()].strip().rstrip('.,;')[:500]
else:
metadata['prompt'] = text.strip()[:500]
patterns = {
'Steps': (r'Steps:\s*(\d+)', lambda x: min(int(x), 30)),
'CFG scale': (r'CFG scale:\s*([\d.]+)', float),
'Seed': (r'Seed:\s*(\d+)', lambda x: int(x) % (2**32)),
'Size': (r'Size:\s*(\d+)x(\d+)', None)
}
for key, (pattern, converter) in patterns.items():
match = re.search(pattern, text)
if match:
try:
if key == 'Size':
metadata['width'] = min(max(int(match.group(1)), 512), 1536)
metadata['height'] = min(max(int(match.group(2)), 512), 1536)
else:
metadata[key.lower().replace(' ', '_')] = converter(match.group(1))
                except Exception:
                    pass  # skip values that fail to convert
    except Exception:
        pass  # keep defaults if the text cannot be parsed
if not metadata['prompt']:
metadata['prompt'] = '(masterpiece, best quality), 1girl'
return metadata
def clear_old_cache():
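    # Evict the least-recently-used model once the cache is full, then release CPU/GPU memory.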
global models_cache
if len(models_cache) >= MAX_CACHE_SIZE:
oldest = min(models_cache.items(), key=lambda x: x[1].get('last_used', 0))
del models_cache[oldest[0]]
gc.collect()
torch.cuda.empty_cache()
@spaces.GPU(duration=119)
def generate_text_gpu(model_id_str, message, history, system, temp, top_p, top_k, max_tokens, rep_penalty):
"""Text generation with branch support"""
global models_cache, stop_event, current_thread
stop_event.clear()
model_id, branch = parse_model_id(model_id_str) # Parse model ID and branch
cache_key = f"{model_id}:{branch}" if branch else model_id
config = MODEL_CONFIGS.get(model_id, {})
if "Luminia-13B-v3" in model_id and ("stable diffusion" in message.lower() or "metadata" in message.lower()):
temp = config.get('sd_temp', 0.3)
top_p = config.get('sd_top_p', 0.8)
print(f"Using SD settings: temp={temp}, top_p={top_p}")
if cache_key not in models_cache:
clear_old_cache()
try:
yield history + [[message, f"📥 Loading {model_id}{f' ({branch})' if branch else ''}..."]], "Loading..."
# Load with branch/revision support
load_kwargs = {"trust_remote_code": True}
if branch:
load_kwargs["revision"] = branch
print(f"Loading from branch: {branch}")
tokenizer = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
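            # 4-bit NF4 quantization with double quant keeps the 8B/13B models within the Space's GPU memory.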
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
model_kwargs = {
"quantization_config": bnb_config,
"device_map": "auto",
"trust_remote_code": True,
"attn_implementation": "flash_attention_2" if torch.cuda.is_available() else None,
"low_cpu_mem_usage": True
}
if branch:
model_kwargs["revision"] = branch
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
models_cache[cache_key] = {
"model": model,
"tokenizer": tokenizer,
"last_used": time.time()
}
except Exception as e:
yield history + [[message, f"❌ Failed: {str(e)[:200]}"]], "Error"
return
models_cache[cache_key]['last_used'] = time.time()
model = models_cache[cache_key]["model"]
tokenizer = models_cache[cache_key]["tokenizer"]
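    # Build an Alpaca-style prompt: optional system text, prior turns, then the new
    # "### Instruction:" / "### Response:" block (pre-formatted messages are passed through as-is).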
prompt = ""
if system:
prompt = f"{system}\n\n"
for user_msg, assistant_msg in history:
if "### Instruction:" in user_msg:
prompt += f"{user_msg}\n### Response:\n{assistant_msg}\n\n"
else:
prompt += f"### Instruction:\n{user_msg}\n\n### Response:\n{assistant_msg}\n\n"
if "### Instruction:" in message and "### Response:" not in message:
prompt += f"{message}\n### Response:\n"
elif "### Instruction:" not in message:
prompt += f"### Instruction:\n{message}\n\n### Response:\n"
else:
prompt += message
print(f"Prompt ending: ...{prompt[-200:]}")
try:
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
input_tokens = inputs['input_ids'].shape[1]
inputs = {k: v.to(model.device) for k, v in inputs.items()}
except Exception as e:
yield history + [[message, f"❌ Tokenization failed: {str(e)}"]], "Error"
return
print(f"📝 {input_tokens} tokens | Temp: {temp} | Top-p: {top_p}")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=5)
gen_kwargs = {
**inputs,
"streamer": streamer,
"max_new_tokens": min(max_tokens, 2048),
"temperature": max(temp, 0.01),
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": rep_penalty,
"do_sample": temp > 0.01,
"pad_token_id": tokenizer.pad_token_id
}
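    # Run generation in a background thread; tokens stream back through TextIteratorStreamer
    # so the chat updates incrementally and the streaming loop can stop early when ⏹ is pressed.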
current_thread = Thread(target=model.generate, kwargs=gen_kwargs)
current_thread.start()
start_time = time.time()
partial = ""
token_count = 0
try:
for text in streamer:
if stop_event.is_set():
break
partial += text
token_count = len(tokenizer.encode(partial, add_special_tokens=False))
elapsed = time.time() - start_time
if elapsed > 0:
yield history + [[message, partial]], f"⚡ {token_count} @ {token_count/elapsed:.1f} t/s"
    except Exception:
        pass  # streamer timed out or generation was interrupted; fall through to cleanup
finally:
if current_thread.is_alive():
stop_event.set()
current_thread.join(timeout=2)
final_time = time.time() - start_time
yield history + [[message, partial]], f"✅ {token_count} tokens in {final_time:.1f}s"
@spaces.GPU()
def generate_image_gpu(text_output):
"""Image generation with pre-loaded Chroma"""
global chroma_pipe
if not text_output or text_output.isspace():
return None, "❌ No valid text", gr.update(visible=False)
try:
metadata = parse_sd_metadata(text_output)
print(f"Generating: {metadata['width']}x{metadata['height']} | Steps: {metadata['steps']}")
if torch.cuda.is_available():
chroma_pipe = chroma_pipe.to("cuda")
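        # Seed the generator explicitly so the same parsed metadata reproduces the same image.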
generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(metadata['seed'])
image = chroma_pipe(
prompt=metadata['prompt'],
negative_prompt=metadata['negative_prompt'],
generator=generator,
num_inference_steps=metadata['steps'],
guidance_scale=metadata['cfg_scale'],
width=metadata['width'],
height=metadata['height']
).images[0]
status = f"✅ {metadata['width']}x{metadata['height']} | {metadata['steps']} steps | CFG: {metadata['cfg_scale']} | Seed: {metadata['seed']}"
return image, status, gr.update(visible=False)
except Exception as e:
import traceback
traceback.print_exc()
return None, f"❌ Failed: {str(e)[:200]}", gr.update(visible=False)
def stop_generation():
global stop_event, current_thread
stop_event.set()
if current_thread and current_thread.is_alive():
current_thread.join(timeout=2)
return gr.update(visible=True), gr.update(visible=False)
css = """
#chatbot {height: 305px;}
#input-row {display: flex; gap: 4px;}
#input-box {flex-grow: 1;}
#button-group {display: inline-flex; flex-direction: column; gap: 2px; width: 45px;}
#button-group button {width: 40px; height: 28px; padding: 2px; font-size: 14px;}
#status {font-size: 11px; color: #666; margin-top: 2px;}
#image-output {max-height: 400px; margin-top: 8px;}
#img-loading {font-size: 11px; color: #666; margin: 4px 0;}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Row(elem_id="input-row"):
msg = gr.Textbox(
label="Instruction",
lines=3,
elem_id="input-box",
value=MODEL_CONFIGS[DEFAULT_MODEL]["examples"][0] if MODEL_CONFIGS[DEFAULT_MODEL]["examples"] else "",
scale=10
)
with gr.Column(elem_id="button-group", scale=1, min_width=45):
submit = gr.Button("▶", variant="primary", size="sm")
stop = gr.Button("⏹", variant="stop", size="sm", visible=False)
undo = gr.Button("↩", size="sm")
clear = gr.Button("🗑", size="sm")
status = gr.Markdown("", elem_id="status")
with gr.Row():
image_btn = gr.Button("🎨 Generate Image using Chroma1-HD", visible=False, variant="secondary")
last_text = gr.Textbox(visible=False)
img_loading = gr.Markdown("", visible=False, elem_id="img-loading")
image_output = gr.Image(visible=False, elem_id="image-output")
image_status = gr.Markdown("", visible=False)
examples = gr.Examples(
examples=[[ex] for ex in MODEL_CONFIGS[DEFAULT_MODEL]["examples"] if ex],
inputs=msg,
label="Examples"
)
with gr.Column(scale=1):
model = gr.Dropdown(
DEFAULT_MODELS,
value=DEFAULT_MODEL,
label="Model",
allow_custom_value=True,
info="Custom HF ID + optional :branch"
)
with gr.Accordion("Settings", open=False):
system = gr.Textbox(
label="System Prompt",
value=MODEL_CONFIGS[DEFAULT_MODEL]["system"],
lines=2
)
temp = gr.Slider(0.1, 1.0, 0.35, label="Temperature")
top_p = gr.Slider(0.5, 1.0, 0.85, label="Top-p")
top_k = gr.Slider(10, 100, 40, label="Top-k")
rep_penalty = gr.Slider(1.0, 1.5, 1.1, label="Repetition Penalty")
max_tokens = gr.Slider(256, 2048, 1024, label="Max Tokens")
export_btn = gr.Button("💾 Export", size="sm")
export_file = gr.File(visible=False)
def update_ui_on_model_change(model_id_str):
"""Update all UI components when model changes"""
model_id, branch = parse_model_id(model_id_str)
config = MODEL_CONFIGS.get(model_id, {"system": "", "examples": [""], "supports_image_gen": False})
        return (
            config["system"],
            config["examples"][0] if config["examples"] else "",
            gr.update(visible=False),              # image_btn
            "",                                    # last_text
            gr.update(value=None, visible=False),  # image_output: clear and hide
            gr.update(value="", visible=False),    # image_status: clear and hide
            gr.update(visible=False)               # img_loading
        )
def check_image_availability(model_id_str, history):
model_id, _ = parse_model_id(model_id_str)
if "Luminia-13B-v3" in model_id and history and len(history) > 0:
return gr.update(visible=True), history[-1][1]
return gr.update(visible=False), ""
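    # Submit chain: swap ▶/⏹, stream the reply, restore the buttons, then offer image
    # generation when the active model produces SD metadata (Luminia-13B-v3).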
submit.click(
lambda: (gr.update(visible=False), gr.update(visible=True)),
None, [submit, stop]
).then(
generate_text_gpu,
[model, msg, chatbot, system, temp, top_p, top_k, max_tokens, rep_penalty],
[chatbot, status]
).then(
lambda: (gr.update(visible=True), gr.update(visible=False)),
None, [submit, stop]
).then(
check_image_availability,
[model, chatbot],
[image_btn, last_text]
)
stop.click(stop_generation, None, [submit, stop])
image_btn.click(
lambda: gr.update(value="🎨 Generating...", visible=True),
None, img_loading
).then(
generate_image_gpu,
last_text,
[image_output, image_status, img_loading]
).then(
lambda img: (gr.update(visible=img is not None), gr.update(visible=True)),
image_output,
[image_output, image_status]
)
model.change(
update_ui_on_model_change,
model,
        [system, msg, image_btn, last_text, image_output, image_status, img_loading]
)
undo.click(
lambda h: h[:-1] if h else h,
chatbot, chatbot
).then(
check_image_availability,
[model, chatbot],
[image_btn, last_text]
)
clear.click(
lambda: ([], "", "", None, "", gr.update(visible=False), "", gr.update(visible=False)),
None, [chatbot, msg, status, image_output, image_status, image_btn, last_text, img_loading]
)
def export_chat(history):
if not history:
return None
content = "\n\n".join([f"User: {u}\n\nAssistant: {a}" for u, a in history])
path = f"chat_{uuid.uuid4().hex[:8]}.txt"
with open(path, "w", encoding="utf-8") as f:
f.write(content)
return path
export_btn.click(export_chat, chatbot, export_file).then(
lambda: gr.update(visible=True), None, export_file
)
demo.queue().launch()