Spaces:
Runtime error
Runtime error
import gradio as gr

from rpc import load_index, generate
def setup():
    """Download the memory-mapped vector index (~10GB), optionally install NGT, and load it.

    Configuration is read from the environment:
      - DATASET_URL: required; URL of a zip archive containing the index files.
      - INDEX_TYPE: index backend name. "ngt" triggers a native NGT install
        and loads from the "index" subdirectory of the archive; any other
        value loads the extraction root directly.

    Raises:
        ValueError: if DATASET_URL is unset, or the archive contains an
            entry whose path would escape the extraction directory.
    """
    import os
    import subprocess

    import requests
    from stream_unzip import stream_unzip

    DATASET_URL = os.environ.get("DATASET_URL")
    INDEX_TYPE = os.environ.get("INDEX_TYPE")
    if not DATASET_URL:
        raise ValueError("DATASET_URL must be set in the environment")

    # /dev/shm keeps the ~10GB index on RAM-backed storage so the memory-mapped
    # loads below are fast.
    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)

    # Stream-extract the zip directly from the HTTP response so the archive is
    # never fully buffered in memory or on disk.  The `with` block ensures the
    # connection is released even if extraction fails part-way.
    with requests.get(DATASET_URL, stream=True) as response:
        response.raise_for_status()
        print("Starting streaming extraction to /dev/shm...")
        for filename, _filesize, file_iter in stream_unzip(response.iter_content(chunk_size=8192)):
            if isinstance(filename, bytes):
                filename = filename.decode('utf-8')
            file_path = os.path.join(extract_dir, filename)
            # Guard against zip-slip: refuse archive entries (e.g. "../x")
            # that would resolve outside the extraction directory.
            if not os.path.normpath(file_path).startswith(os.path.normpath(extract_dir) + os.sep):
                raise ValueError(f"Unsafe path in archive: {filename}")
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'wb') as f_out:
                for chunk in file_iter:
                    f_out.write(chunk)
            # NOTE(review): the pasted source logged a literal "(unknown)" here;
            # logging the actual entry name is almost certainly the intent.
            print(f"Extracted: {filename} -> {file_path}")

    files = [f for f in os.listdir(extract_dir) if os.path.isfile(os.path.join(extract_dir, f))]
    for f in files:
        print(f)
    print("Index extracted")

    if INDEX_TYPE == "ngt":
        print("Installing NGT...")
        subprocess.check_call(["bash", "install_ngt.sh"])
        print("NGT installed")

    print("Loading index...")
    # NGT archives nest the index under an "index/" subdirectory; other
    # backends load straight from the extraction root.
    if INDEX_TYPE == "ngt":
        index_dir = extract_dir + "/index"
    else:
        index_dir = extract_dir
    load_index(index_path=index_dir, idx_type=INDEX_TYPE)
    print("Index loaded")
def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    """Stream a chat completion for *message* given the running *history*.

    Flattens the conversation into a single "<s>User: ...\\nAI: ..." prompt,
    then yields the cumulative response text as each token arrives from
    `generate`.

    NOTE(review): `temperature` is accepted (it is wired to a UI slider) but
    never forwarded to `generate` — confirm whether `generate` supports a
    temperature argument before passing it through.
    """
    past_turns = [
        f"{user_name}: {past_user.strip()}\n{ai_name}: {past_ai.strip()}\n"
        for past_user, past_ai in history
    ]
    prompt = "<s>" + "".join(past_turns) + f"{user_name}: {message.strip()}\n{ai_name}:"

    partial = ""
    for token in generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens):
        partial += token
        yield partial
    print(history, message, partial)
# Chat UI.  The extra controls let the user rename both speakers, toggle the
# RPC-enhanced retrieval path, and tune generation length and temperature.
_additional_inputs = [
    gr.Textbox(value="Jake", label="User name"),
    gr.Textbox(value="Sarah", label="AI name"),
    gr.Checkbox(
        label="Use RPC",
        info="Compare Normal vs. RPC-Enhanced Model",
        value=True,
    ),
    gr.Slider(minimum=1, maximum=320, value=128, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=3.0, value=0.2, step=0.1, label="Temperature (only used without RPC)"),
]

demo = gr.ChatInterface(
    respond,
    additional_inputs=_additional_inputs,
    description="Remember that you are talking with a 5M parameter model trained on allenai/soda, not ChatGPT"
)
if __name__ == "__main__":
    # Fetch/extract the index (and install NGT if configured) before the UI
    # starts accepting requests.
    setup()
    demo.launch()