Spaces:
Runtime error
Upload app.py
app.py
ADDED
@@ -0,0 +1,352 @@
from typing import Optional, Generator
import os
from pathlib import Path
import tarfile
from dataclasses import dataclass

import torch
import lancedb
from lancedb.embeddings import get_registry
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
import gradio as gr


@dataclass
class Settings:
    """Settings class to store useful variables for the app."""

    LANCEDB: str = "lancedb"
    LANCEDB_FILE_TAR: str = "lancedb.tar.gz"
    TOKEN: str = os.getenv("HF_API_TOKEN")
    LOCAL_DIR: Path = Path.home() / ".cache/argilla_sdk_docs_db"
    REPO_ID: str = "plaguss/argilla_sdk_docs_queries"
    TABLE_NAME: str = "docs"
    MODEL_NAME: str = "plaguss/bge-base-argilla-sdk-matryoshka"
    DEVICE: str = (
        "mps"
        if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available() else "cpu"
    )
    MODEL_ID: str = "meta-llama/Meta-Llama-3-70B-Instruct"


settings = Settings()


def untar_file(source: Path) -> Path:
    """Untar and decompress a file previously created with `make_tarfile`.

    Args:
        source (Path): Path pointing to a .tar.gz file.

    Returns:
        filename (Path): Path to the decompressed file.
    """
    new_filename = source.parent / source.stem.replace(".tar", "")
    with tarfile.open(source, "r:gz") as f:
        f.extractall(source.parent)
    return new_filename


def download_database(
    repo_id: str,
    lancedb_file: str = "lancedb.tar.gz",
    local_dir: Path = Path.home() / ".cache/argilla_sdk_docs_db",
    token: str = os.getenv("HF_API_TOKEN"),
) -> Path:
    """Helper function to download the database. Downloads a compressed lancedb
    stored in a Hugging Face repository.

    Args:
        repo_id: Name of the repository where the database file is stored.
        lancedb_file: Name of the compressed file containing the lancedb database.
            Defaults to "lancedb.tar.gz".
        local_dir: Path where the file will be downloaded to. Defaults to
            Path.home() / ".cache/argilla_sdk_docs_db".
        token: Token for the Hugging Face Hub API. Defaults to os.getenv("HF_API_TOKEN").

    Returns:
        The path pointing to the database, already uncompressed and ready to be used.
    """
    lancedb_download = Path(
        hf_hub_download(
            repo_id,
            lancedb_file,
            repo_type="dataset",
            token=token,
            local_dir=local_dir,
        )
    )
    return untar_file(lancedb_download)


# Get the model to create the embeddings
model = get_registry().get("sentence-transformers").create(
    name=settings.MODEL_NAME, device=settings.DEVICE
)


class Database:
    """Interaction with the vector database to retrieve the chunks.

    On instantiation, it will download the lancedb database if it is not already
    found in the expected location. Once ready, the only functionality available is
    to retrieve the doc chunks to be used as examples for the LLM.
    """

    def __init__(self, settings: Settings) -> None:
        self.settings = settings
        self._table: lancedb.table.LanceTable = self.get_table_from_db()

    def get_table_from_db(self) -> lancedb.table.LanceTable:
        """Obtains the table of the database containing the embedded docs.

        If the database is not found in the expected location, it will be downloaded
        first; then the connection is created and the table opened.

        Returns:
            The table of the database containing the embedded chunks.
        """
        lancedb_db_path = self.settings.LOCAL_DIR / self.settings.LANCEDB

        if not lancedb_db_path.exists():
            lancedb_db_path = download_database(
                self.settings.REPO_ID,
                lancedb_file=self.settings.LANCEDB_FILE_TAR,
                local_dir=self.settings.LOCAL_DIR,
                token=self.settings.TOKEN,
            )

        db = lancedb.connect(str(lancedb_db_path))
        table = db.open_table(self.settings.TABLE_NAME)
        return table

    def retrieve_doc_chunks(self, query: str, limit: int = 12, hard_limit: int = 4) -> str:
        """Searches the database for chunks similar to the query and builds the context string.

        TODO: SPLIT IN TWO SEPARATE FUNCTIONS TO PREPARE THE CONTEXT.

        Args:
            query (str): The user query to search for.
            limit (int, optional): Number of chunks to retrieve from the table.
                Defaults to 12.
            hard_limit (int, optional): Maximum number of unique chunks kept for the
                context. Defaults to 4.

        Returns:
            str: The unique chunks concatenated into a single string, ready to be used
                as context for the LLM.
        """
        # Embed the query to use our custom model instead of the default one.
        embedded_query = model.generate_embeddings([query])
        field_to_retrieve = "text"
        retrieved = (
            self._table
            .search(embedded_query[0])
            .metric("cosine")
            .limit(limit)
            .select([field_to_retrieve])  # Just grab the chunk to use for context
            .to_list()
        )
        # We have repeated questions (up to 4) for a given chunk, so we may get repeated chunks.
        # Request more than necessary and filter them afterwards.
        responses = []
        unique_responses = set()

        for item in retrieved:
            chunk = item["text"]
            if chunk not in unique_responses:
                unique_responses.add(chunk)
                responses.append(chunk)

        context = ""
        for i, item in enumerate(responses[:hard_limit]):
            if i > 0:
                context += "\n\n"
            context += f"---\n{item}"
        return context


database = Database(settings=settings)


def get_client_and_tokenizer(
    model_id: str = settings.MODEL_ID,
    tokenizer_id: Optional[str] = None,
) -> tuple[InferenceClient, AutoTokenizer]:
    """Obtains the inference client and the tokenizer corresponding to the model.

    Args:
        model_id: The name of the model. Currently it must be one in the free tier.
            Defaults to "meta-llama/Meta-Llama-3-70B-Instruct".
        tokenizer_id: The name of the corresponding tokenizer. Defaults to None,
            in which case it will use the same as the `model_id`.

    Returns:
        The client and tokenizer chosen.
    """
    if tokenizer_id is None:
        tokenizer_id = model_id

    client = InferenceClient()
    base_url = client._resolve_url(model=model_id, task="text-generation")
    # Note: We could move to the AsyncClient
    client = InferenceClient(model=base_url, token=os.getenv("HF_API_TOKEN"))

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    return client, tokenizer


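# Generation parameters forwarded to `InferenceClient.text_generation` in `chatty` below;
# with `stream=True` the client yields the response token by token.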
client_kwargs = {
    "stream": True,
    "max_new_tokens": 512,
    "do_sample": False,
    "typical_p": None,
    "repetition_penalty": None,
    "temperature": 0.3,
    "top_p": None,
    "top_k": None,
    "stop_sequences": (
        ["<|eot_id|>", "<|end_of_text|>"]
        if settings.MODEL_ID.startswith("meta-llama/Meta-Llama-3")
        else None
    ),
    "seed": None,
}


client, tokenizer = get_client_and_tokenizer()

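# Prompts: SYSTEM_PROMPT sets the assistant's behaviour, and ARGILLA_BOT_TEMPLATE is the
# user-turn template that `prepare_input` fills with the query and the retrieved context.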
SYSTEM_PROMPT = """\
You are a support expert in Argilla SDK, whose goal is to help users with their questions.
As a trustworthy expert, you must provide truthful answers to questions using only the provided documentation snippets, not prior knowledge.
Here are guidelines you must follow when responding to user questions:

**Purpose and Functionality**
- Answer questions related to the Argilla SDK.
- Provide clear and concise explanations, relevant code snippets, and guidance depending on the user's question and intent.
- Ensure users succeed in effectively understanding and using Argilla's features.
- Provide accurate responses to the user's questions.

**Specificity**
- Be specific and provide details only when required.
- Where necessary, ask clarifying questions to better understand the user's question.
- Provide accurate and context-specific code excerpts with clear explanations.
- Ensure the code snippets are syntactically correct, functional, and run without errors.
- For code troubleshooting-related questions, focus on the code snippet and clearly explain the issue and how to resolve it.
- Avoid boilerplate code such as imports, installs, etc.

**Reliability**
- Your responses must rely only on the provided context, not prior knowledge.
- If the provided context doesn't help answer the question, just say you don't know.
- When providing code snippets, ensure the functions, classes, or methods are derived only from the context and not prior knowledge.
- Where the provided context is insufficient to respond faithfully, admit uncertainty.
- Remind the user of your specialization in Argilla SDK support when a question is outside your domain of expertise.
- Redirect the user to the appropriate support channel - the Argilla [community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g) - when the question is outside your capabilities or you do not have enough context to answer it.

**Response Style**
- Use clear, concise, professional language suitable for technical support.
- Do not refer to the context in the response (e.g., "As mentioned in the context..."); instead, provide the information directly in the response.

**Example**:

The correct answer to the user's query

Steps to solve the problem:
- **Step 1**: ...
- **Step 2**: ...
...

Here's a code snippet

```python
# Code example
...
```

**Explanation**:

- Point 1
- Point 2
...
"""

ARGILLA_BOT_TEMPLATE = """\
Please provide an answer to the following question related to Argilla's new SDK.

You can make use of the chunks of documents in the context to help you generate the response.

## Query:
{message}

## Context:
{context}
"""


def prepare_input(message: str, history: list[tuple[str, str]]) -> str:
    """Prepares the input to be passed to the LLM.

    Args:
        message: Message from the user, the query.
        history: Previous list of messages from the user and the answers, as a list
            of tuples with user/assistant messages.

    Returns:
        The string with the template formatted to be sent to the LLM.
    """
    # Retrieve the context from the database
    context = database.retrieve_doc_chunks(message)

    # Prepare the conversation for the model.
    conversation = []
    for human, bot in history:
        conversation.append({"role": "user", "content": human})
        conversation.append({"role": "assistant", "content": bot})

    conversation.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    conversation.append(
        {
            "role": "user",
            "content": ARGILLA_BOT_TEMPLATE.format(message=message, context=context),
        }
    )

    return tokenizer.apply_chat_template(
        [conversation],
        tokenize=False,
        add_generation_prompt=True,
    )[0]


def chatty(message: str, history: list[tuple[str, str]]) -> Generator[str, None, None]:
    """Main function of the app, contains the interaction with the LLM.

    Args:
        message: Message from the user, the query.
        history: Previous list of messages from the user and the answers, as a list
            of tuples with user/assistant messages.

    Yields:
        The streaming response; it's printed in the interface as it's being received.
    """
    prompt = prepare_input(message, history)

    partial_message = ""
    for token_stream in client.text_generation(prompt=prompt, **client_kwargs):
        partial_message += token_stream
        yield partial_message


if __name__ == "__main__":
    gr.ChatInterface(
        chatty,
        chatbot=gr.Chatbot(height=600),
        textbox=gr.Textbox(placeholder="Ask me about the new argilla SDK", container=False, scale=7),
        title="Argilla SDK Chatbot",
        description="Ask a question about Argilla SDK",
        theme="soft",
        examples=[
            "How can I connect to an argilla server?",
            "How can I access a dataset?",
            "How can I get the current user?",
        ],
        cache_examples=True,
        retry_btn=None,
    ).launch()
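Since the Space is currently in a Runtime error state, a quick way to narrow down where it fails is to exercise the retrieval path on its own, before the Gradio UI and the Inference API come into play. The snippet below is only an illustrative sketch, not part of the Space: it assumes the file above is saved as `app.py`, that `HF_API_TOKEN` is set, and that importing the module (which downloads the database, the embedding model, and the Llama 3 tokenizer) succeeds.

```python
# Minimal, illustrative sanity check: importing app.py runs all the module-level
# setup, so any failure there should surface here with a full traceback.
from app import database, prepare_input

# Retrieve the documentation chunks used as context for a sample query.
print(database.retrieve_doc_chunks("How can I connect to an argilla server?"))

# Inspect the full prompt that would be sent to the LLM, with an empty chat history.
print(prepare_input("How can I connect to an argilla server?", history=[]))
```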