File size: 1,921 Bytes
5be1376
15776f9
 
5be1376
15776f9
 
 
 
 
 
9132bb2
 
15776f9
 
5be1376
 
 
15776f9
5be1376
 
 
c19dc01
15776f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c19dc01
5be1376
 
 
 
 
 
c19dc01
5be1376
15776f9
 
 
c19dc01
 
 
5be1376
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the checkpoint once at import time; every request reuses these objects.
model_name = "prithivMLmods/rStar-Coder-Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# nn.Module.eval() returns the module itself, so it can be chained straight
# into torch.compile (compilation is lazy and happens on the first call).
model = torch.compile(AutoModelForCausalLM.from_pretrained(model_name).eval())

# Use the GPU when one is present; otherwise keep the model on the CPU.
model = model.to("cuda") if torch.cuda.is_available() else model

# Module-level conversation log of "User: ..." / "Bot: ..." lines.
# NOTE(review): this single list is shared process-wide.
history = []

def _stop_token_ids():
    """Return the set of token ids that end generation.

    Combines ``tokenizer.eos_token_id`` with ``model.generation_config.eos_token_id``,
    which may be a single int or a list (some checkpoints declare several
    end-of-sequence tokens).
    """
    ids = {tokenizer.eos_token_id}
    config_eos = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
    if isinstance(config_eos, int):
        ids.add(config_eos)
    elif config_eos is not None:
        ids.update(config_eos)
    ids.discard(None)
    return ids


def stream_chat(user_input, turns=None, *, max_new_tokens=200, stop_text="\nUser:"):
    """Greedily generate a reply to *user_input*, yielding the growing reply text.

    The prompt is the prior turns (``"User: ..."`` / ``"Bot: ..."`` lines) joined
    by newlines, followed by ``"\\nBot:"``. The new user turn is appended to
    *turns* up front and the bot turn is appended when generation ends, even if
    the caller stops iterating early.

    Args:
        user_input: The new user message.
        turns: List of prompt lines to extend in place. Defaults to the
            module-level ``history`` list, which is shared by every caller that
            does not pass its own list.
        max_new_tokens: Maximum number of tokens to generate.
        stop_text: If this text appears in the decoded reply, the reply is cut
            just before it and generation stops. Pass ``""`` to disable.

    Yields:
        The cumulative reply decoded so far (not a delta).
    """
    if turns is None:
        turns = history
    turns.append(f"User: {user_input}")

    bot_reply = ""
    try:
        context = "\n".join(turns) + "\nBot:"
        input_ids = tokenizer(context, return_tensors="pt").input_ids
        if torch.cuda.is_available():
            input_ids = input_ids.to("cuda")

        stop_ids = _stop_token_ids()
        generated = []
        step_input = input_ids
        past_key_values = None
        for _ in range(max_new_tokens):
            # Keep no_grad scoped to the forward pass: the generator is suspended
            # at each yield (and may be resumed from another thread), so grad
            # mode should not stay toggled across a yield.
            with torch.no_grad():
                outputs = model(step_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            next_token = int(torch.argmax(outputs.logits[0, -1, :]))

            # Test for end-of-sequence before decoding so the end token never
            # leaks into the visible reply.
            if next_token in stop_ids:
                break
            generated.append(next_token)

            # Decode all generated ids each step: a single character may be split
            # across several tokens, and decoding tokens one at a time can then
            # produce replacement characters.
            text = tokenizer.decode(generated, skip_special_tokens=True)
            cut = text.find(stop_text) if stop_text else -1
            if cut != -1:
                # The model started writing the next user turn itself; drop it.
                bot_reply = text[:cut]
                yield bot_reply
                break
            bot_reply = text
            yield bot_reply

            # With the KV cache holding the prefix, only the newest token is fed next.
            step_input = input_ids.new_tensor([[next_token]])
    finally:
        # Record the bot turn even if the consumer closed the generator early,
        # so turns always alternate User/Bot.
        turns.append(f"Bot: {bot_reply.strip()}")


# Gradio interface
with gr.Blocks() as demo:
    chatbot_ui = gr.Chatbot()
    msg = gr.Textbox(placeholder="Type a message...")

    def respond(user_input, chat_history):
        """Stream the bot's answer for one message into this session's history.

        The prompt turns are rebuilt from *chat_history* (this session's
        ``gr.State``), so different sessions do not share a conversation.
        """
        turns = []
        for past_user, past_bot in chat_history:
            turns.append(f"User: {past_user}")
            turns.append(f"Bot: {past_bot.strip()}")

        chat_history.append((user_input, ""))
        # Show the user's message right away, before the first token arrives.
        yield chat_history, chat_history
        for partial in stream_chat(user_input, turns):
            chat_history[-1] = (user_input, partial)
            yield chat_history, chat_history

    state = gr.State([])
    msg.submit(respond, [msg, state], [chatbot_ui, state])

demo.launch()