import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load tokenizer and model
model_name = "prithivMLmods/rStar-Coder-Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
model = torch.compile(model)
if torch.cuda.is_available():
    model = model.to("cuda")

history = []

def stream_chat(user_input):
    global history
    history.append(f"User: {user_input}")
    context = "\n".join(history) + "\nBot:"

    # Tokenize input
    input_ids = tokenizer(context, return_tensors="pt").input_ids
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")

    # Generate token by token
    output_ids = input_ids.clone()
    bot_reply = ""
    max_new_tokens = 200  # adjust as needed

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(output_ids)
        next_token_logits = outputs.logits[0, -1, :]
        next_token = torch.argmax(next_token_logits).unsqueeze(0)
        output_ids = torch.cat([output_ids, next_token.unsqueeze(0)], dim=1)

        token_str = tokenizer.decode(next_token)
        bot_reply += token_str

        # Yield streaming output
        yield bot_reply

        # Stop if EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    history.append(f"Bot: {bot_reply}")

# Gradio interface
with gr.Blocks() as demo:
    chatbot_ui = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type a message...")

    def respond(user_input, chat_history):
        chat_history.append((user_input, ""))
        for partial in stream_chat(user_input):
            chat_history[-1] = (user_input, partial)
            yield chat_history, chat_history

    state = gr.State([])
    msg.submit(respond, [msg, state], [chatbot_ui, state])

demo.launch()
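The app above streams by running a full forward pass per token and taking the argmax, which keeps the loop explicit but recomputes the prefix on every step. For comparison, here is a minimal sketch of the same streaming idea built on Transformers' TextIteratorStreamer, which runs generate() in a background thread and yields decoded text chunks as they arrive; the model name is reused from above, and the exact streamer behavior may vary with the installed transformers version.

    # Minimal streaming sketch using TextIteratorStreamer (assumes a recent transformers release)
    from threading import Thread

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

    model_name = "prithivMLmods/rStar-Coder-Qwen3-0.6B"  # same model as the app above
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    if torch.cuda.is_available():
        model = model.to("cuda")

    def stream_reply(prompt, max_new_tokens=200):
        # Tokenize and move tensors to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # skip_prompt drops the echoed input; skip_special_tokens drops EOS etc.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)
        # generate() blocks, so run it in a worker thread and read from the streamer here
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        partial = ""
        for chunk in streamer:
            partial += chunk
            yield partial  # same "growing reply" contract as stream_chat above
        thread.join()

A generator like stream_reply could be dropped into the Gradio respond callback in place of stream_chat, since both yield the progressively longer bot reply.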