import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load tokenizer and model
model_name = "prithivMLmods/rStar-Coder-Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
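# torch.compile builds an optimized graph of the model for faster inference
# (compilation is triggered lazily on the first forward pass)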
model = torch.compile(model)
if torch.cuda.is_available():
    model = model.to("cuda")

history = []
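
# stream_chat: greedy, token-by-token decoding over the running conversation
# history; yields the growing bot reply so the UI can stream it as it is generated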
def stream_chat(user_input):
    global history
    history.append(f"User: {user_input}")
    context = "\n".join(history) + "\nBot:"

    # Tokenize input
    input_ids = tokenizer(context, return_tensors="pt").input_ids
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")

    # Generate token by token
    output_ids = input_ids.clone()
    bot_reply = ""
    max_new_tokens = 200  # adjust as needed

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(output_ids)
        next_token_logits = outputs.logits[0, -1, :]
        next_token = torch.argmax(next_token_logits).unsqueeze(0)
        output_ids = torch.cat([output_ids, next_token.unsqueeze(0)], dim=1)
        token_str = tokenizer.decode(next_token)
        bot_reply += token_str

        # Yield streaming output
        yield bot_reply

        # Stop if EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    history.append(f"Bot: {bot_reply}")

# Gradio interface
with gr.Blocks() as demo:
    chatbot_ui = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type a message...")

    def respond(user_input, chat_history):
        chat_history.append((user_input, ""))
        for partial in stream_chat(user_input):
            chat_history[-1] = (user_input, partial)
            yield chat_history, chat_history

    state = gr.State([])
    msg.submit(respond, [msg, state], [chatbot_ui, state])

demo.launch()
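
# Assuming gradio, torch, and transformers are installed, the app can be started
# locally with: python app.py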