rule / app.py
qewrufda's picture
Update app.py
f940641 verified
# app.py
import json
import threading
import gradio as gr
import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from peft import PeftModel
# -----------------------------
# 0. Environment setup
# -----------------------------
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)
# -----------------------------
# 1. Model / path settings
# -----------------------------
# Base causal LM identifier passed to from_pretrained().
BASE_MODEL = "KORMo-Team/KORMo-10B-sft"
# Directory holding the LoRA adapter that is applied on top of the base model.
LORA_DIR = "peft_lora"
# JSON file containing the documents used for retrieval.
DOC_PATH = "rule.json"
# -----------------------------
# 2. Load RAG documents + prepare FAISS
# -----------------------------
# The corpus is read as JSON and iterated as a sequence of entries that each
# expose a "text" key.
with open(DOC_PATH, "r", encoding="utf-8") as f:
    documents = json.load(f)
doc_texts = [entry["text"] for entry in documents]

# Sentence encoder used for both the corpus and, later, the incoming queries.
embedding_model = SentenceTransformer("jhgan/ko-sroberta-multitask", device=device)

# Embed the whole corpus once at start-up; the vectors are cast to float32
# before being handed to FAISS.
doc_embs = embedding_model.encode(doc_texts, convert_to_numpy=True)
doc_embs = doc_embs.astype("float32")

# Flat L2 index sized to the embedding width (second axis of the matrix).
embedding_dim = doc_embs.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
index.add(doc_embs)
def retrieve(query, k=3):
    """Return up to ``k`` documents whose embeddings are closest to ``query``.

    The query is embedded with the module-level ``embedding_model`` and looked
    up in the module-level FAISS ``index``; the matching entries of
    ``documents`` are returned in the order FAISS reports them.

    Args:
        query: Text to search for.
        k: Maximum number of documents to return.

    Returns:
        A list with at most ``k`` entries of ``documents``. The list is empty
        when the index holds no vectors or ``k`` is not positive.
    """
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with the label -1, and documents[-1] would silently return the
    # last document instead of "no result".
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    q = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
    _, neighbor_ids = index.search(q, k)
    # Defensive filter in case any padding label is still present.
    return [documents[i] for i in neighbor_ids[0] if i >= 0]
print("FAISS ready, docs:", index.ntotal)
# -----------------------------
# 3. Load tokenizer + model (including LoRA)
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# Fall back to the EOS token for padding when the tokenizer defines no pad
# token, so that pad_token_id passed to generate() is not None.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Base model in half precision; device placement is delegated to
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
)
# Wrap the base model with the LoRA adapter stored in LORA_DIR.
model = PeftModel.from_pretrained(
    model, LORA_DIR, device_map="auto", torch_dtype=torch.float16
)
# Inference only: put the wrapped model into evaluation mode.
model.eval()
print("Model + LoRA loaded")
# -----------------------------
# 4. Prompt builder
# -----------------------------
def build_prompt(persona, instruction, query, retrieved_docs):
    """Assemble the text prompt sent to the language model.

    The prompt starts with an empty line, followed by five "### <heading>:"
    sections (persona, notes, rules, question, answer) and ends with a newline
    right after the final heading so the model continues from there. The
    heading strings appear mis-encoded in this copy of the file and are
    reproduced as they are.

    Args:
        persona: Persona description placed under the first heading.
        instruction: Extra guidance placed under the second heading.
        query: The user question placed under the fourth heading.
        retrieved_docs: Iterable of mappings with a "text" key; each becomes
            one "- <text>" bullet under the third heading.

    Returns:
        The complete prompt string.
    """
    bullets = []
    for doc in retrieved_docs:
        bullets.append(f"- {doc['text']}")
    context = "\n".join(bullets)

    # One entry per output line; the empty first and last entries give the
    # leading and trailing newline of the prompt.
    prompt_lines = (
        "",
        "### ํŽ˜๋ฅด์†Œ๋‚˜:",
        f"{persona}",
        "### ์ฐธ๊ณ ์‚ฌํ•ญ:",
        f"{instruction}",
        "### ๊ทœ์ •:",
        f"{context}",
        "### ์งˆ๋ฌธ:",
        f"{query}",
        "### ๋‹ต๋ณ€:",
        "",
    )
    return "\n".join(prompt_lines)
# -----------------------------
# 5. For the streaming UI
# -----------------------------
def generate_stream(persona, instruction, query, max_new_tokens=256):
    """Generate an answer and yield the growing text as it is produced.

    Retrieves the three nearest documents for ``query``, builds the prompt,
    and runs sampling-based generation in a worker thread while this
    generator consumes the streamer.

    Args:
        persona: Persona description inserted into the prompt.
        instruction: Extra guidance inserted into the prompt.
        query: The user question.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The text generated so far (cumulative, prompt excluded).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    retrieved = retrieve(query, k=3)
    prompt = build_prompt(persona, instruction, query, retrieved)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    failures = []

    def run_generate():
        try:
            with torch.no_grad():
                model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.7,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    streamer=streamer,
                    use_cache=True,
                )
        except Exception as exc:
            # Worker-thread boundary: without this the consumer loop below
            # would wait forever on the streamer queue. Record the error and
            # close the stream so the consumer can finish and re-raise it.
            failures.append(exc)
            streamer.end()

    thread = threading.Thread(target=run_generate)
    thread.start()

    accumulated = ""
    for chunk in streamer:
        accumulated += chunk
        yield accumulated

    # The stream is exhausted, so the worker is finished or about to be;
    # join it so no thread is left dangling per request.
    thread.join()
    if failures:
        raise failures[0]
# -----------------------------
# 6. Synchronous generation for the API
# -----------------------------
def generate_once(persona, instruction, query, max_new_tokens=256):
    """Generate one complete answer and return it as a string.

    Retrieves the three nearest documents for ``query``, builds the prompt,
    and runs sampling-based generation to completion.

    Args:
        persona: Persona description inserted into the prompt.
        instruction: Extra guidance inserted into the prompt.
        query: The user question.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated continuation only (prompt excluded), stripped of
        surrounding whitespace.
    """
    retrieved = retrieve(query, k=3)
    prompt = build_prompt(persona, instruction, query, retrieved)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    # Drop the prompt by token position rather than by string replacement:
    # the decoded text is not guaranteed to reproduce the prompt character for
    # character (special tokens are skipped, whitespace may be normalised), in
    # which case str.replace would leave the whole prompt in the answer.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# -----------------------------
# 7. Persona group
# -----------------------------
persona_group = [
("๋‹น์‹ ์€ ์›์น™์„ ์ง€ํ‚ค๋˜ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์œ ์—ฐํ•˜๊ฒŒ ํŒ๋‹จํ•˜๋Š” ์‹œ๊ฐ์„ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค. ๊ฐœ์ธ์˜ ๋Šฅ๋ ฅ๊ณผ ๊ธฐ์—ฌ๋„๋ฅผ ์ค‘์š”ํ•˜๊ฒŒ ์ƒ๊ฐํ•˜์—ฌ ์„ฑ๊ณผ์— ๋”ฐ๋ฅธ ์ฐจ๋“ฑ ๋Œ€์šฐ๋ฅผ ์ •๋‹นํ•˜๋‹ค๊ณ  ํŒ๋‹จํ•˜๋ฉฐ, ๋ณ€ํ™”์™€ ํ˜์‹ ์„ ์ตœ์šฐ์„ ์œผ๋กœ ์—ฌ๊ฒจ ๊ด€์Šต๋ณด๋‹ค ๊ฐœ์„ ์„ ์„ ํƒํ•œ๋‹ค. ๋˜ํ•œ ๋‚ด๋ถ€์— ๋จธ๋ฌด๋ฅด๊ธฐ๋ณด๋‹ค ์™ธ๋ถ€์™€์˜ ์—ฐ๊ณ„์™€ ํ˜‘์—…์„ ์ ๊ทน์ ์œผ๋กœ ์ถ”๊ตฌํ•˜๋ฉฐ, ํ•™์ˆ  ํ™œ๋™๊ณผ ์นœ๋ชฉ ํ™œ๋™์˜ ๊ท ํ˜•์„ ํ†ตํ•ด ๊ฑด๊ฐ•ํ•œ ๊ณต๋™์ฒด ๋ฌธํ™”๋ฅผ ์ง€ํ–ฅํ•œ๋‹ค. ๋Œ€์™ธ์ ์œผ๋กœ ๋ณด์—ฌ์ค„ ์ˆ˜ ์žˆ๋Š” ํ™•์‹คํ•œ ์„ฑ๊ณผ์™€ ์™„์„ฑ๋„๋ฅผ ์ค‘์‹œํ•˜๋ฉด์„œ๋„, ๋‹จ๊ธฐ์  ํ•ด๊ฒฐ๊ณผ ์žฅ๊ธฐ์  ๊ธฐ๋ฐ˜ ๋งˆ๋ จ ์‚ฌ์ด์—์„œ ๊ท ํ˜•์„ ์œ ์ง€ํ•˜๋ ค ๋…ธ๋ ฅํ•œ๋‹ค.", '๋ฐ•์„ธ์—ฐ'),
("๋‹น์‹ ์€ ๊ณต์ •ํ•œ ๊ทœ์น™๊ณผ ์›์น™์„ ์ค‘์‹œํ•˜๋ฉด์„œ, ๊ฐœ์ธ์˜ ์„ฑ๊ณผ์™€ ๋Šฅ๋ ฅ์„ ์ธ์ •ํ•ด ์ฐจ๋“ฑ์„ ๋‘๊ณ  ๋ฐฐ๋ถ„ํ•ฉ๋‹ˆ๋‹ค. ์ „ํ†ต์„ ์กด์ค‘ํ•˜๋˜ ์ ์ง„์ ์ธ ๊ฐœ์„ ์„ ์ˆ˜์šฉํ•˜๋ฉฐ, ๋‚ด๋ถ€ ํ™œ๋™์— ๋จธ๋ฌด๋ฅด์ง€ ์•Š๊ณ  ์™ธ๋ถ€์™€์˜ ํ˜‘์—…๊ณผ ๋„คํŠธ์›Œํฌ๋ฅผ ์ ๊ทน์ ์œผ๋กœ ์ถ”๊ตฌํ•ฉ๋‹ˆ๋‹ค. ํšŒ์› ๊ฐ„ ์œ ๋Œ€์™€ ์ฆ๊ฑฐ์›€์„ ์ค‘์š”์‹œํ•˜๊ณ , ์™„์„ฑ๋„ ๋†’์€ ๊ฒฐ๊ณผ๋ฌผ๊ณผ ๊ณผ์ •์—์„œ์˜ ๋ฐฐ์›€์„ ๋ชจ๋‘ ์ค‘์‹œํ•˜๋ฉฐ, ๋‹น์žฅ์˜ ๋ฌธ์ œ ํ•ด๊ฒฐ๊ณผ ์žฅ๊ธฐ์  ๊ธฐ๋ฐ˜ ๊ตฌ์ถ•์„ ๋™์‹œ์— ๊ณ ๋ คํ•ฉ๋‹ˆ๋‹ค.",'๊น€์ฐฝ์ค€'),
("๊ทœ์œจ๊ณผ ์ž์œจ์˜ ๊ท ํ˜•์„ ์ง€ํ‚ค๋ฉฐ, ๋Šฅ๋ ฅ๊ณผ ์„ฑ๊ณผ๋ฅผ ๊ธฐ์ค€์œผ๋กœ ํŒ๋‹จํ•œ๋‹ค. ์ „ํ†ต์„ ์œ ์ง€ํ•˜๋˜ ์ ์ง„์  ๊ฐœ์„ ์„ ์ถ”๊ตฌํ•˜๊ณ , ์™ธ๋ถ€์™€์˜ ํ˜‘์—…์„ ์ ๊ทน์ ์œผ๋กœ ๋ชจ์ƒ‰ํ•œ๋‹ค. ์ฆ๊ฑฐ์šด ๋ถ„์œ„๊ธฐ ์†์—์„œ ํ•™์Šตํ•˜๋ฉฐ ๊ฐœ์ธ์˜ ์„ฑ์žฅ์„ ์ค‘์‹œํ•˜๊ณ , ๋‹จ๊ธฐ ์„ฑ๊ณผ๋ณด๋‹ค ๋™์•„๋ฆฌ์˜ ์žฅ๊ธฐ์  ๊ธฐ๋ฐ˜์„ ์šฐ์„ ํ•œ๋‹ค.", '์ด์ƒ๊ธฐ'),
("๊ทœ์œจ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•˜์ง€๋งŒ ์œ ์—ฐํ•˜๋ฉฐ, ๋ถ„๋ฐฐ๋Š” ์ค‘๋ฆฝ์ ์ด๊ณ  ๊ฐœ์„ ์„ ์ถ”๊ตฌํ•œ๋‹ค. ์™ธ๋ถ€ ์—ฐ๊ณ„๋ฅผ ์ ๋‹นํžˆ ํ™œ์šฉํ•˜๋ฉฐ ํ•™์ˆ ยท์นœ๋ชฉ ๋ชจ๋‘ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์„ ํƒํ•˜๊ณ , ๊ฐ€์‹œ์„ฑ๊ณผ ์žฅ๊ธฐ ๊ธฐ๋ฐ˜์„ ์กฐํ™”๋กญ๊ฒŒ ๊ณ ๋ คํ•œ๋‹ค.", '์ฑ„ํ›ˆ'),
("์ž์œจ์„ ์กด์ค‘ํ•˜๋˜ ์ตœ์†Œํ•œ์˜ ๊ทœ์œจ์„ ์œ ์ง€ํ•˜๋ฉฐ, ๊ธฐ์—ฌ๋„์™€ ๊ฐœ์„ ์„ ๊ท ํ˜• ์žˆ๊ฒŒ ๋ฐ˜์˜ํ•œ๋‹ค. ๋‚ด๋ถ€์™€ ์™ธ๋ถ€ ํ™œ๋™์„ ์ƒํ™ฉ์— ๋”ฐ๋ผ ์กฐ์ ˆํ•˜๊ณ  ํ•™์ˆ ๊ณผ ์นœ๋ชฉ ๋ชจ๋‘๋ฅผ ํฌ์šฉํ•˜๋ฉฐ, ์„ฑ์žฅ๊ณผ ์žฅ๊ธฐ ๊ธฐ๋ฐ˜์„ ์ค‘์‹œํ•˜๋Š” ์‹ค์šฉ์  ์šด์˜์„ ์„ ํ˜ธํ•œ๋‹ค.", '์šฉ์šฐ'),
("๊ทœ์œจ๊ณผ ๊ณต์ •์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์•ˆ์ •์ ์ธ ์šด์˜์„ ์ถ”๊ตฌํ•˜๋ฉฐ, ๊ท ๋“ฑยท๊ฐœ์„ ยท์นœ๋ชฉ ๊ฐ„์˜ ๊ท ํ˜•์„ ์ค‘์‹œํ•œ๋‹ค. ๋‚ด๋ถ€ ์ค‘์‹ฌ์ด๋˜ ํ•„์š”์— ๋”ฐ๋ผ ์™ธ๋ถ€ ํ˜‘๋ ฅ์„ ์ˆ˜์šฉํ•˜๊ณ , ์„ฑ์žฅ๊ณผ ์žฅ๊ธฐ ๊ธฐ๋ฐ˜์„ ํ•จ๊ป˜ ๊ณ ๋ คํ•˜๋Š” ์‹ค์šฉ์  ํŒ๋‹จ์„ ์ง€ํ–ฅํ•œ๋‹ค.",'ํ˜•์ง„')
]
instruction_text = """
ํ•ด๋‹น ํŽ˜๋ฅด์†Œ๋‚˜์˜ ์„ฑ๊ฒฉ์„ ๊ฐ€์ง„ ์‹ฌํŒ๊ด€์ž…๋‹ˆ๋‹ค.
๋ฐ˜๋“œ์‹œ 3๋ฌธ์žฅ๋งŒ ๋งํ•˜์‹ญ์‹œ์˜ค.
๊ฐ ๋ฌธ์žฅ์€ 30์ž ์ด๋‚ด๋กœ ์ œํ•œํ•ฉ๋‹ˆ๋‹ค.
๊ทœ์ •์— ๊ทผ๊ฑฐํ•˜์—ฌ ๋‹ตํ•˜์‹œ์˜ค.
๋ฐ˜๋ณต ๊ธˆ์ง€, ํŒ๋‹จ ๊ทผ๊ฑฐ ํ•„์ˆ˜.
"""
# -----------------------------
# 8. For the streaming UI
# -----------------------------
def run_all_streaming(query):
    """Stream the answers of every persona as one growing Markdown document.

    Each yielded value is the full transcript so far. The UI replaces the
    output component's value on every yield, so yielding only the newest
    fragment would wipe the persona headers and all earlier answers.

    Args:
        query: The user question put to every persona.

    Yields:
        The cumulative Markdown text: for each persona a "## ..." header, the
        answer generated so far, and a horizontal-rule separator once that
        persona's answer is complete.
    """
    transcript = ""
    for persona, name in persona_group:
        header = f"## ๐Ÿ‘ค {name}\n"
        yield transcript + header
        answer = ""
        # generate_stream yields the cumulative answer for this persona.
        for answer in generate_stream(persona, instruction_text, query):
            yield transcript + header + answer
        transcript += header + answer + "\n\n---\n\n"
        yield transcript
# -----------------------------
# 9. Synchronous run for the API (returns a string)
# -----------------------------
def run_all_api(query):
    """Ask every persona the same question and return all answers at once.

    Args:
        query: The user question put to every persona.

    Returns:
        One string holding, per persona in ``persona_group`` order, a
        "## ..." header line, the generated answer, and a horizontal-rule
        separator.
    """
    sections = []
    for persona, name in persona_group:
        answer = generate_once(persona, instruction_text, query)
        sections.append(f"## ๐Ÿ‘ค {name}\n" + answer + "\n\n---\n\n")
    return "".join(sections)
# -----------------------------
# 10. Gradio app layout
# -----------------------------
demo = gr.Blocks()
with demo:
    gr.Markdown("# ๐Ÿ”ฅ KORMo LoRA + RAG (Streaming UI + API)")

    # Question box shared by both buttons; pre-filled with a sample question.
    user_input = gr.Textbox(
        label="์งˆ๋ฌธ ์ž…๋ ฅ",
        value="3๋ฒˆ ์ด์ƒ์˜ ๊ฒฐ์„์„ ํ–ˆ์ง€๋งŒ ์‹ค๋ ฅ์€ ๋›ฐ์–ด๋‚œ ์ •ํšŒ์›์„ ์–ด๋–ป๊ฒŒ ํ•ด์•ผ ํ• ๊นŒ?",
    )

    # Streaming UI: run_all_streaming is a generator, so the Markdown area is
    # updated on every yield.
    output_stream = gr.Markdown()
    run_btn = gr.Button("๐Ÿš€ ์‹คํ–‰(Streaming UI)")
    run_btn.click(
        fn=run_all_streaming,
        inputs=[user_input],
        outputs=[output_stream],
    )

    # API button (synchronous return): the handler is registered under the
    # endpoint name "start_api".
    api_output = gr.Textbox(label="API ๋ฐ˜ํ™˜ ๊ฒฐ๊ณผ", lines=15)
    run_btn_api = gr.Button("๐Ÿ” ์‹คํ–‰(API)")
    run_btn_api.click(
        fn=run_all_api,
        inputs=[user_input],
        outputs=[api_output],
        api_name="start_api",
    )
# -----------------------------
# 11. Launch
# -----------------------------
# Start serving the Blocks app defined in this file. share=True asks Gradio
# for a public share link in addition to the local URL (see Gradio's launch()
# documentation).
demo.launch(share=True)