File size: 7,148 Bytes
b376f12
 
 
5facf65
 
c3fb36e
5facf65
c3fb36e
5facf65
 
 
 
 
b376f12
 
 
5facf65
 
 
 
1dfbab6
5facf65
 
 
 
 
 
b376f12
 
 
 
 
 
5facf65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae9a3e
fed41be
5facf65
1dfbab6
5facf65
 
 
 
 
fed41be
5facf65
 
1dfbab6
5facf65
 
53cb438
7fff69d
5facf65
b376f12
03c2ae6
 
2e11c33
03c2ae6
db24877
03c2ae6
 
 
 
5facf65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e11c33
5facf65
31c9425
5facf65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e11c33
5facf65
 
b376f12
 
 
 
03c2ae6
edd0bac
5facf65
 
edd0bac
5facf65
 
edd0bac
5facf65
 
03c2ae6
edd0bac
 
 
db24877
edd0bac
 
 
 
 
 
 
 
 
 
03c2ae6
 
da09cca
ead1131
cbef7a0
da09cca
 
 
 
3ba38dc
da09cca
19342c6
 
3ba38dc
 
 
 
 
 
da09cca
 
 
03c2ae6
 
d7174fa
03c2ae6
 
 
 
 
da09cca
 
b376f12
c3fb36e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread
from huggingface_hub import hf_hub_download, login
from themes.research_monochrome import ResearchMonochrome
from typing import Iterator, List, Dict

import os
import requests
import json
import subprocess
import gradio as gr

# Human-readable date such as "March 5, 2025" (day not zero-padded).
# The day is formatted from the integer rather than with "%-d", which is a
# glibc-only strftime extension that raises ValueError on Windows.
_today = datetime.today()  # noqa: DTZ002
today_date = f"{_today:%B} {_today.day}, {_today.year}"

SYS_PROMPT = f"""Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite 4 Micro served from local GGUF server"
DESCRIPTION = """
<p>Granite 4 Micro is an open-source LLM supporting a 1M context window. This demo uses only 2K context and max 1K output tokens.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
# Base URL of the local llama-server; its port is also what the server is launched on.
LLAMA_CPP_SERVER = "http://127.0.0.1:8081"
# Default generation settings (also the initial values of the UI sliders).
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

# Probe for an NVIDIA GPU. `nvidia-smi` exiting successfully selects the CUDA
# build; a missing binary or a non-zero exit status falls back to CPU.
try:
    subprocess.run(
        ["nvidia-smi"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=True,
    )
except (subprocess.CalledProcessError, FileNotFoundError):
    platform = "CPU"
else:
    platform = "CUDA"

print(f"Detected platform {platform}")

# Fetch the Q2_K_XL GGUF weights of Granite 4.0 H Micro (Unsloth repo) into
# the current working directory; `gguf_path` is the local path returned.
gguf_name = "granite-4.0-h-micro-UD-Q2_K_XL.gguf"
gguf_path = hf_hub_download(
    local_dir=".",
    filename=gguf_name,
    repo_id="unsloth/granite-4.0-h-micro-GGUF",
)

import atexit

# Pick the prebuilt llama-server binary matching the detected platform.
exe_name = "llama-server-t3-6266-cuda" if platform == "CUDA" else "llama-server-t3-6268-blas"
exe_path = hf_hub_download(
            repo_id="TobDeBer/Skipper",
            filename=exe_name,
            local_dir="."
)

# Start llama-server.
# Work from the paths hf_hub_download actually returned (resolved to absolute
# paths) rather than re-deriving them from names and the working directory.
exe_file = Path(exe_path).resolve()
model_file = Path(gguf_path).resolve()
# The downloaded binary may lack execute permission; add the x bits in-process
# (an error here surfaces instead of being silently ignored like `chmod` was).
exe_file.chmod(exe_file.stat().st_mode | 0o111)
# Launch on the same port the client connects to, so the two cannot drift apart.
llama_port = LLAMA_CPP_SERVER.rsplit(":", 1)[-1]
command = [str(exe_file), "-m", str(model_file), "-c", "2048", "--port", llama_port]
process = subprocess.Popen(command)
print(f"Llama-server process started with PID {process.pid}")
# Don't leave the server running after the app exits.
atexit.register(process.terminate)

# Instantiate the project's monochrome Gradio theme and log its class for debugging.
custom_theme = ResearchMonochrome()
print("Theme type:", custom_theme.__class__)

def _format_granite_prompt(conversation: List[Dict]) -> str:
    """Render chat messages as a single prompt in Granite's role-token format.

    Each system/user/assistant message becomes
    ``<|start_of_role|>{role}<|end_of_role|>{content}<|end_of_text|>`` plus a
    newline. An open assistant turn is appended so the model replies as the
    assistant. Messages with any other role are skipped.

    NOTE(review): mirrors the Granite chat template; re-check against the
    GGUF's embedded template if the model is swapped.
    """
    parts = [
        f"<|start_of_role|>{item['role']}<|end_of_role|>{item['content']}<|end_of_text|>\n"
        for item in conversation
        if item.get("role") in ("system", "user", "assistant")
    ]
    parts.append("<|start_of_role|>assistant<|end_of_role|>")
    return "".join(parts)


def _with_error(partial: str, error: Exception) -> str:
    """Return an error notice, appended to any partial reply so streamed text isn't lost."""
    notice = f"Error: {error}"
    return f"{partial}\n\n{notice}" if partial else notice


def generate(
    message: str,
    chat_history: List[Dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a chat reply from the local llama.cpp server.

    The system prompt, prior turns and the new user message are rendered in
    Granite's chat format and POSTed to ``{LLAMA_CPP_SERVER}/completion`` with
    streaming enabled.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio ``type="messages"`` format).
        temperature: Sampling temperature forwarded to the server.
        repetition_penalty: Sent as the server's ``repeat_penalty``.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff (cast to int before sending).
        max_new_tokens: Upper bound on generated tokens.

    Yields:
        The cumulative reply text after each streamed chunk. Errors are not
        raised: an ``"Error: ..."`` message (appended to any partial reply) is
        yielded instead so the chat UI always receives a response.
    """
    conversation = [
        {"role": "system", "content": SYS_PROMPT},
        *chat_history,
        {"role": "user", "content": message},
    ]

    payload = {
        "prompt": _format_granite_prompt(conversation),
        "stream": True,
        # `n_predict` is the /completion field; `max_tokens` kept for compatibility.
        "n_predict": int(max_new_tokens),
        "max_tokens": int(max_new_tokens),
        "temperature": temperature,
        "repeat_penalty": repetition_penalty,
        "top_p": top_p,
        # UI sliders can deliver floats; top_k is an integer count.
        "top_k": int(top_k),
        # Granite ends a turn with <|end_of_text|>; also stop if it opens a new role.
        "stop": ["<|end_of_text|>", "<|start_of_role|>"],
    }

    reply = ""
    try:
        with requests.post(f"{LLAMA_CPP_SERVER}/completion", json=payload, stream=True, timeout=(30, 300)) as response:
            response.raise_for_status()  # HTTPError for 4xx/5xx

            for line in response.iter_lines():
                if not line:
                    continue  # streamed events are separated by blank lines
                decoded_line = line.decode("utf-8").removeprefix("data: ")
                try:
                    json_data = json.loads(decoded_line)
                except json.JSONDecodeError:
                    print(f"JSONDecodeError: {decoded_line}")
                    continue
                if not isinstance(json_data, dict):
                    continue

                text = json_data.get("content", "")
                if text:
                    reply += text
                    yield reply
                # The server flags its final event with "stop": true.
                if json_data.get("stop"):
                    break

    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        yield _with_error(reply, e)
    except Exception as e:  # UI boundary: surface anything else to the user
        print(f"An unexpected error occurred: {e}")
        yield _with_error(reply, e)


# Stylesheet located next to this script.
css_file_path = Path(__file__).parent / "app.css"

# advanced settings (displayed in Accordion)
temperature_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
)
top_p_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
)
repetition_penalty_slider = gr.Slider(
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    # Capped at MAX_NEW_TOKENS: the description promises at most 1K output
    # tokens, and the server runs with a 2048-token context, so larger budgets
    # would leave little room for the prompt.
    maximum=MAX_NEW_TOKENS,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
)
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)

# (prompt, button label) pairs shown as clickable examples in the chat UI.
_EXAMPLE_PROMPTS = (
    (
        "Explain the concept of quantum computing to someone with no background in physics or computer science.",
        "Explain quantum computing",
    ),
    ("What is OpenShift?", "What is OpenShift?"),
    ("What's the importance of low latency inference?", "Importance of low latency inference"),
    ("Help me boost productivity habits.", "Boosting productivity habits"),
)

# Page layout: title, description, then the streaming chat interface. The
# additional inputs are listed in the same order as generate()'s extra parameters.
with gr.Blocks(title=TITLE, theme=custom_theme, css_paths=css_file_path, fill_height=True) as demo:
    gr.HTML(f"<h2>{TITLE}</h2>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chat_interface = gr.ChatInterface(
        fn=generate,
        type="messages",
        examples=[[prompt] for prompt, _label in _EXAMPLE_PROMPTS],
        example_labels=[label for _prompt, label in _EXAMPLE_PROMPTS],
        cache_examples=False,
        additional_inputs=[
            temperature_slider,
            repetition_penalty_slider,
            top_p_slider,
            top_k_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )

if __name__ == "__main__":
    # Enable the request queue (needed for streaming generators), then serve.
    queued_demo = demo.queue()
    queued_demo.launch()