Spaces:
Paused
Paused
| import os | |
| import gradio as gr | |
| from typing import Iterator | |
| from dialog import get_dialog_box | |
| from gateway import check_server_health, request_generation | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Upper bound offered by the "Max New Tokens" slider in the chat UI.
MAX_NEW_TOKENS: int = 2048

# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
# Gateway endpoint handed to the health-check and generation helpers;
# None when the API_ENDPOINT variable is not set.
CLOUD_GATEWAY_API = os.environ.get("API_ENDPOINT")
def toggle_ui():
    """
    Pick which top-level container is visible based on backend health.

    Returns:
        A pair of ``gr.update`` objects for ``(main UI, dialog box)``. When the
        server is healthy, the main UI is shown and the dialog hidden. Otherwise
        the main UI is hidden and the dialog shown.
    """
    healthy = bool(check_server_health(cloud_gateway_api=CLOUD_GATEWAY_API))
    # The two containers are mutually exclusive: exactly one is visible.
    return gr.update(visible=healthy), gr.update(visible=not healthy)
def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming response and emit it to the UI.

    Args:
        message (str): input message from the user.
        chat_history (list): chat history of the session as supplied by the chat
            interface. NOTE(review): it is not forwarded to ``request_generation``,
            so earlier turns are not part of the prompt sent to the model. Confirm
            whether multi-turn context is intended.
        system_prompt (str): system prompt.
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring
            the number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next token
            probabilities. Defaults to 0.6.
        top_p (float, optional): if set to float < 1, only the smallest set of most
            probable tokens with probabilities that add up to top_p or higher are kept
            for generation. Defaults to 0.9.
        top_k (int, optional): the number of highest probability vocabulary tokens to
            keep for top-k filtering. Defaults to 50.
        repetition_penalty (float, optional): the parameter for repetition penalty.
            1.0 means no penalty. Defaults to 1.2.

    Yields:
        Iterator[str]: the concatenation of every chunk received so far. Each yield
        is the full response text up to that point, not just the newest chunk.
    """
    # Accumulate chunks and always emit the running total, so each yielded value
    # replaces the displayed assistant message with the complete text so far.
    outputs = []
    for text in request_generation(message=message,
                                   system_prompt=system_prompt,
                                   max_new_tokens=max_new_tokens,
                                   temperature=temperature,
                                   top_p=top_p,
                                   top_k=top_k,
                                   repetition_penalty=repetition_penalty,
                                   cloud_gateway_api=CLOUD_GATEWAY_API):
        outputs.append(text)
        yield "".join(outputs)
# Chat UI wired to `generate`. It is built at module level and placed into the
# page layout later via `chat_interface.render()`.
chat_interface = gr.ChatInterface(
    fn=generate,
    # These inputs are passed to `generate` after (message, chat_history), in
    # this order: system_prompt, max_new_tokens, temperature, top_p, top_k,
    # repetition_penalty.
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=1024,
        ),
        # NOTE(review): the UI default temperature (0.1) differs from
        # `generate`'s default (0.6). Confirm which value is intended.
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        ),
        # NOTE(review): the UI default top_p (0.95) differs from `generate`'s
        # default (0.9). Confirm which value is intended.
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    # Presumably hides the stop button. Verify against the installed Gradio version.
    stop_btn=None,
    # Example prompts fill only the message box; additional inputs keep their values.
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    # Examples are not pre-run at startup, so no backend calls happen at launch.
    cache_examples=False,
    chatbot=gr.Chatbot(
        height=600)
)
with gr.Blocks(css="style.css", theme=gr.themes.Default()) as demo:
    # Probe the backend once while building the page. The initial render then
    # shows the chat UI when healthy, or the waiting dialog otherwise.
    visibility = check_server_health(CLOUD_GATEWAY_API)

    # Container for the main chat interface.
    with gr.Column(visible=visibility, elem_id="main_ui") as main_ui:
        gr.Markdown("""
        # Llama-3 8B Chat
        This Space is an Alpha release that demonstrates [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model running on AMD MI210 infrastructure. The space is built with Meta Llama 3 [License](https://www.llama.com/llama3/license/). Feel free to play with it!
        """)
        chat_interface.render()

    # Dialog box shown instead of the main UI while the server is unhealthy.
    with gr.Row(visible=(not visibility), elem_id="dialog_box") as dialog_box:
        # Add spinner and message
        get_dialog_box()

    # Re-check server health every 10 seconds (gr.Timer's value is in seconds)
    # and swap which container is visible.
    timer = gr.Timer(value=10)
    timer.tick(fn=toggle_ui, outputs=[main_ui, dialog_box])
def queue_settings_from_env() -> dict:
    """Build keyword arguments for ``demo.queue`` from environment variables.

    ``QUEUE`` sets ``max_size`` and ``CONCURRENCY_LIMIT`` sets
    ``default_concurrency_limit``. An unset or blank variable is left out of the
    result, so Gradio applies its own default. Previously such a variable made
    ``int(None)`` raise ``TypeError`` at startup.

    Returns:
        dict: keyword arguments suitable for ``Blocks.queue``.

    Raises:
        ValueError: if a variable is set to text that is not an integer.
    """
    settings = {}
    for env_name, kwarg_name in (
        ("QUEUE", "max_size"),
        ("CONCURRENCY_LIMIT", "default_concurrency_limit"),
    ):
        raw_value = os.getenv(env_name, "").strip()
        if raw_value:
            settings[kwarg_name] = int(raw_value)
    return settings


if __name__ == "__main__":
    demo.queue(**queue_settings_from_env()).launch()