Spaces:
Running
on
Zero
Running
on
Zero
| # This space is mostly a copy of the work of Aritra Roy Gosthipaty (see https://huggingface.co/spaces/ariG23498/kv-press/blob/main/app.py) | |
| import spaces | |
| import requests | |
| import gradio as gr | |
| from bs4 import BeautifulSoup | |
| from transformers import pipeline | |
| from kvpress import ( | |
| ExpectedAttentionPress, | |
| KnormPress, | |
| RandomPress, | |
| SnapKVPress, | |
| StreamingLLMPress, | |
| TOVAPress, | |
| ) | |
# Registry of the KV-cache compression strategies offered in the UI.
# Each key is the label shown in the "Select Press" dropdown; each value is
# the kvpress class that process_request instantiates with the chosen ratio.
press_dict = dict(
    ExpectedAttentionPress=ExpectedAttentionPress,
    KnormPress=KnormPress,
    RandomPress=RandomPress,
    SnapKVPress=SnapKVPress,
    StreamingLLMPress=StreamingLLMPress,
    TOVAPress=TOVAPress,
)
# One "kv-press-text-generation" pipeline per supported checkpoint, keyed by
# the checkpoint name. All pipelines are created eagerly when the module is
# loaded and placed on the first CUDA device.
pipe_dict = {
    ckpt: pipeline(
        "kv-press-text-generation",
        model=ckpt,
        device="cuda:0",
        torch_dtype="auto",
    )
    for ckpt in [
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "Qwen/Qwen2.5-7B-Instruct-1M",
    ]
}
# NOTE(review): `spaces` is imported at the top of the file but no
# `@spaces.GPU` decorator is visible on this handler. If the app runs on
# ZeroGPU hardware the decorator is probably required here -- confirm.
def process_request(url, question, press_name, pipe_name, compression_ratio):
    """Answer a question about a web article using a compressed KV cache.

    The page at ``url`` is downloaded, the text of every ``<p>`` element is
    concatenated into a single context string, and the selected pipeline is
    run on that context with the selected press.

    Args:
        url: Address of the article to download.
        question: Question forwarded to the pipeline together with the context.
        press_name: Key of ``press_dict`` selecting the press class.
        pipe_name: Key of ``pipe_dict`` selecting the pipeline (model).
        compression_ratio: Value passed to the press constructor; also used
            to estimate the token count after compression.

    Returns:
        A 3-tuple ``(text, num_tokens, num_tokens_after_compression)``.
        ``text`` is the pipeline's answer on success or an error message on
        failure; on failure both counts are ``-1``. Every code path returns
        exactly three values because the UI binds this function to three
        output components.
    """
    if press_name not in press_dict:
        return f"Invalid press selected: {press_name}", -1, -1
    if pipe_name not in pipe_dict:
        return f"Invalid model selected: {pipe_name}", -1, -1

    # Fetch the Wikipedia article
    try:
        headers = {'User-Agent': 'kvpress/1.0 (https://github.com/NVIDIA/kvpress; kvpress@nvidia.com) requests/2.31.0'}
        # A timeout keeps an unresponsive host from blocking the handler
        # indefinitely; raise_for_status turns 4xx/5xx replies into a
        # RequestException instead of letting an error page be parsed.
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        content = response.content
    except requests.exceptions.RequestException as e:
        return f"Error fetching the Wikipedia article: {str(e)}", -1, -1

    try:
        # Parse the Wikipedia HTML
        soup = BeautifulSoup(content, "html.parser")
        article_text = "".join(p.text for p in soup.find_all("p"))
        # Check the extracted text (not the raw response body) so that a page
        # without any paragraph text is reported instead of being sent to the
        # model as an empty context.
        if not article_text.strip():
            return "Error parsing the Wikipedia article", -1, -1
        context = article_text + "\n\n"

        # Initialize the press
        press = press_dict[press_name](compression_ratio)
        pipe = pipe_dict[pipe_name]
        num_tokens = pipe.tokenizer(context, return_tensors="pt")["input_ids"].shape[1]
        pred_answer = pipe(context, question=question, press=press)["answer"]
        return pred_answer, num_tokens, int(num_tokens * (1 - compression_ratio))
    except Exception as e:
        # Top-level boundary for the UI: report the failure as text rather
        # than letting the exception escape the handler.
        if "CUDA out of memory" in str(e):
            return "Error: CUDA out of memory. Try using a smaller article or a lower compression ratio.", -1, -1
        return str(e), -1, -1
def gradio_interface():
    """Assemble the Gradio Blocks UI for the kvpress question-answering demo.

    The layout consists of an introductory markdown panel, a row with the
    article URL and question text boxes, a row with the model / press /
    compression-ratio controls, the three output widgets, a submit button and
    a set of clickable examples. The submit button is wired to
    ``process_request``.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    intro_markdown = (
        "\n"
        "# Wikipedia Article Question Answering with kvpress\n"
        "This demo answers questions about any given Wikipedia article.\n"
        "Under the hood, [kvpress](https://github.com/NVIDIA/kvpress) *compresses the key-value (KV) cache* associated with the article, helping reduce memory usage and accelerate decoding.\n"
        "**How to use:**\n"
        "1. Enter a Wikipedia article URL\n"
        "2. Type your question\n"
        "3. Select a model, a press and the desired compression ratio\n"
        '4. Press "Submit" to see the answer, along with token statistics before and after compression\n'
    )

    # (article URL, question) pairs; every example uses the same press and ratio.
    example_queries = [
        (
            "https://en.wikipedia.org/wiki/Nvidia",
            "Complete this sentence: In May 2017, the program had 1,300 companies. As of March 2018, there were ",
        ),
        (
            "https://en.wikipedia.org/wiki/Hugging_Face",
            "What was the original name of the transformers library ?",
        ),
        (
            "https://en.wikipedia.org/wiki/World_Chess_Championship_2024",
            "On which move did the world chess championship end?",
        ),
    ]
    example_rows = [
        [article_url, query, "ExpectedAttentionPress", 0.5]
        for article_url, query in example_queries
    ]

    with gr.Blocks() as ui:
        gr.Markdown(intro_markdown)

        # First row: what to read and what to ask.
        with gr.Row():
            url_box = gr.Textbox(
                label="Wikipedia Article URL",
                placeholder="Enter the Wikipedia article URL here",
            )
            question_box = gr.Textbox(
                label="Question",
                placeholder="Type your question here",
            )

        # Second row: model, press and compression settings.
        # NOTE(review): row membership is reconstructed from syntax -- confirm
        # that the slider belongs in this row.
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(pipe_dict),
                value="meta-llama/Meta-Llama-3.1-8B-Instruct",
                label="Select Model",
            )
            press_dropdown = gr.Dropdown(
                choices=list(press_dict),
                value="ExpectedAttentionPress",
                label="Select Press",
            )
            ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=0.9,
                step=0.1,
                value=0.5,
                label="Compression Ratio",
            )

        # Result widgets, in the order process_request returns its values.
        answer_box = gr.Textbox(label="Output", lines=10)
        tokens_before = gr.Number(
            label="Number of tokens before compression",
            interactive=False,
        )
        tokens_after = gr.Number(
            label="Number of tokens after compression",
            interactive=False,
        )
        run_button = gr.Button("Submit")

        # Examples fill every input except the model dropdown, which keeps
        # its current value.
        gr.Examples(
            examples=example_rows,
            inputs=[url_box, question_box, press_dropdown, ratio_slider],
        )

        # Input order matches the positional parameters of process_request.
        run_button.click(
            process_request,
            inputs=[url_box, question_box, press_dropdown, model_dropdown, ratio_slider],
            outputs=[answer_box, tokens_before, tokens_after],
        )

    return ui
if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start serving it.
    # Launch demo
    demo = gradio_interface()
    demo.launch()