| """ | |
| The widgets defines utility functions for loading data, text cleaning, | |
| and indexing documents, as well as classes for handling document queries | |
| and formatting chat history. | |
| """ | |
import re
import pickle
import string
import logging
import configparser
from enum import Enum
from typing import List, Tuple, Union

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import torch
import tiktoken
from langchain.vectorstores import Chroma
from langchain.schema import Document, BaseMessage
from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings

tokenizer_name = tiktoken.encoding_for_model("gpt-3.5-turbo")
tokenizer = tiktoken.get_encoding(tokenizer_name.name)
# If the NLTK stopwords, punkt, and wordnet resources are not downloaded, download them.
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords")

try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

try:
    nltk.data.find("corpora/wordnet")
except LookupError:
    nltk.download("wordnet")
ChatTurnType = Union[Tuple[str, str], BaseMessage]
_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}

class Config:
    """Initializes configs."""

    def __init__(self, config_file):
        self.config = configparser.ConfigParser(interpolation=None)
        self.config.read(config_file)

        # Tokens
        self.openai_api_key = self.config.get("tokens", "OPENAI_API_KEY")
        self.anthropic_api_key = self.config.get("tokens", "ANTHROPIC_API_KEY")
        self.together_api_key = self.config.get("tokens", "TOGETHER_API_KEY")
        self.huggingface_token = self.config.get("tokens", "HUGGINGFACE_TOKEN")
        self.version = self.config.get("tokens", "VERSION")

        # Directory
        self.docs_dir = self.config.get("directory", "DOCS_DIR")
        self.db_dir = self.config.get("directory", "db_DIR")
        self.local_model_dir = self.config.get("directory", "LOCAL_MODEL_DIR")

        # Parameters
        self.model_name = self.config.get("parameters", "MODEL_NAME")
        self.temperature = self.config.getfloat("parameters", "TEMPURATURE")
        self.max_chat_history = self.config.getint("parameters", "MAX_CHAT_HISTORY")
        self.max_llm_context = self.config.getint("parameters", "MAX_LLM_CONTEXT")
        self.max_llm_generation = self.config.getint("parameters", "MAX_LLM_GENERATION")
        self.embedding_name = self.config.get("parameters", "EMBEDDING_NAME")
        self.n_gpu_layers = self.config.getint("parameters", "N_GPU_LAYERS")
        self.n_batch = self.config.getint("parameters", "N_BATCH")
        self.base_chunk_size = self.config.getint("parameters", "BASE_CHUNK_SIZE")
        self.chunk_overlap = self.config.getint("parameters", "CHUNK_OVERLAP")
        self.chunk_scale = self.config.getint("parameters", "CHUNK_SCALE")
        self.window_steps = self.config.getint("parameters", "WINDOW_STEPS")
        self.window_scale = self.config.getint("parameters", "WINDOW_SCALE")
        self.retriever_weights = [
            float(x.strip())
            for x in self.config.get("parameters", "RETRIEVER_WEIGHTS").split(",")
        ]
        self.first_retrieval_k = self.config.getint("parameters", "FIRST_RETRIEVAL_K")
        self.second_retrieval_k = self.config.getint("parameters", "SECOND_RETRIEVAL_K")
        self.num_windows = self.config.getint("parameters", "NUM_WINDOWS")

        # Logging
        self.logging_enabled = self.config.getboolean("logging", "enabled")
        self.logging_level = self.config.get("logging", "level")
        self.logging_filename = self.config.get("logging", "filename")
        self.logging_format = self.config.get("logging", "format")
        self.configure_logging()

    def configure_logging(self):
        """
        Configure the logger shared by every .py file.
        """
        if not self.logging_enabled:
            logging.disable(logging.CRITICAL + 1)
            return

        log_level = self.config.get("logging", "level")
        log_filename = self.config.get("logging", "filename")
        log_format = self.config.get("logging", "format")
        logging.basicConfig(level=log_level, filename=log_filename, format=log_format)
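
# A hedged sketch of the INI layout Config expects, inferred from the get/getint/getfloat
# calls above; every value shown below is a placeholder assumption, not part of this module.
#
#   [tokens]
#   OPENAI_API_KEY = <your key>
#   ANTHROPIC_API_KEY = <your key>
#   TOGETHER_API_KEY = <your key>
#   HUGGINGFACE_TOKEN = <your token>
#   VERSION = 0.1
#
#   [directory]
#   DOCS_DIR = ./docs
#   db_DIR = ./db
#   LOCAL_MODEL_DIR = ./models
#
#   [parameters]
#   MODEL_NAME = gpt-3.5-turbo
#   TEMPURATURE = 0.0
#   ...
#
#   [logging]
#   enabled = True
#   level = INFO
#   filename = app.log
#   format = %(asctime)s - %(levelname)s - %(message)s
#
#   config = Config("configparser.ini")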

def configure_logger():
    """
    Configure the logger for every .py file.
    """
    config = configparser.ConfigParser(interpolation=None)
    config.read("configparser.ini")

    enabled = config.getboolean("logging", "enabled")
    if not enabled:
        logging.disable(logging.CRITICAL + 1)
        return

    log_level = config.get("logging", "level")
    log_filename = config.get("logging", "filename")
    log_format = config.get("logging", "format")
    logging.basicConfig(level=log_level, filename=log_filename, format=log_format)

def tiktoken_len(text):
    """Return the token length of `text` under the tiktoken tokenizer."""
    tokens = tokenizer.encode(text, disallowed_special=())
    return len(tokens)
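
# Quick illustration of tiktoken_len; the exact count depends on the gpt-3.5-turbo
# encoding shipped with the installed tiktoken version, so treat the number as indicative.
#
#   tiktoken_len("How many tokens is this sentence?")  # -> a small int, e.g. 7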

def check_device():
    """Check whether CUDA or MPS is available, else fall back to CPU."""
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    return device

def choose_embeddings(embedding_name):
    """Choose the embedding model for a given embedding name."""
    try:
        if embedding_name == "openAIEmbeddings":
            return OpenAIEmbeddings()
        elif embedding_name == "hkunlpInstructorLarge":
            device = check_device()
            return HuggingFaceInstructEmbeddings(
                model_name="hkunlp/instructor-large", model_kwargs={"device": device}
            )
        else:
            device = check_device()
            return HuggingFaceEmbeddings(
                model_name=embedding_name, model_kwargs={"device": device}
            )
    except Exception as error:
        raise ValueError(f"Embedding {embedding_name} not supported") from error
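
# Hedged usage sketch: the two special names mirror the string checks above; any other
# value is passed straight to HuggingFaceEmbeddings as a Hugging Face model id. The
# sentence-transformers id below is an example assumption, not something this module defines.
#
#   embedding = choose_embeddings("hkunlpInstructorLarge")
#   embedding = choose_embeddings("sentence-transformers/all-MiniLM-L6-v2")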

def load_embedding(store_name, embedding, suffix, path):
    """Load a Chroma vector store persisted at {path}/chroma_{store_name}_{suffix}."""
    vector_store = Chroma(
        persist_directory=f"{path}/chroma_{store_name}_{suffix}",
        embedding_function=embedding,
    )
    return vector_store

def load_pickle(prefix, suffix, path):
    """Load langchain documents from a pickle file.

    Args:
        prefix (str): Prefix of the pickle file name.
        suffix (str): Suffix to append to the file name.
        path (str): The directory where the pickle file is stored.

    Returns:
        Document: documents from the pickle file
    """
    with open(f"{path}/{prefix}_{suffix}.pkl", "rb") as file:
        return pickle.load(file)
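
# Hedged sketch of how load_embedding and load_pickle fit together; the store name
# "chunks_medium", the prefix "docs_pickle", and the suffix "test" are made-up examples,
# not values this module defines.
#
#   config = Config("configparser.ini")
#   embedding = choose_embeddings(config.embedding_name)
#   vector_store = load_embedding("chunks_medium", embedding, "test", config.db_dir)
#   docs = load_pickle("docs_pickle", "test", config.db_dir)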

def clean_text(text):
    """
    Convert text to lowercase, remove punctuation and stopwords, and lemmatize it
    for the BM25 retriever.

    Parameters:
        text (str): The text to be cleaned.

    Returns:
        str: The cleaned and lemmatized text.
    """
    # Remove [SEP] markers in the text
    text = text.replace("[SEP]", "")
    # Tokenization
    tokens = word_tokenize(text)
    # Lowercasing
    tokens = [w.lower() for w in tokens]
    # Remove punctuation
    table = str.maketrans("", "", string.punctuation)
    stripped = [w.translate(table) for w in tokens]
    # Keep tokens that are alphabetic, numeric, or contain both.
    words = [
        word
        for word in stripped
        if word.isalpha()
        or word.isdigit()
        or (re.search(r"\d", word) and re.search(r"[a-zA-Z]", word))
    ]
    # Remove stopwords
    stop_words = set(stopwords.words("english"))
    words = [w for w in words if w not in stop_words]
    # Lemmatization (stemming could be used instead)
    lemmatizer = WordNetLemmatizer()
    lemmatized = [lemmatizer.lemmatize(w) for w in words]
    # Convert the list of words back to a string
    return " ".join(lemmatized)
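
# Example of what clean_text does, assuming the NLTK punkt/stopwords/wordnet data is
# present; the exact output can vary slightly with the installed NLTK corpora.
#
#   clean_text("The cats [SEP] are sitting on 2 mats!")
#   # -> "cat sitting 2 mat"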

class IndexerOperator(Enum):
    """
    Enumeration for different query operators used in indexing.
    """

    EQ = "=="
    GT = ">"
    GTE = ">="
    LT = "<"
    LTE = "<="

class DocIndexer:
    """
    A class to handle indexing and searching of documents.

    Attributes:
        documents (List[Document]): List of documents to be indexed.
    """

    def __init__(self, documents):
        self.documents = documents
        self.index = self.build_index(documents)

    def build_index(self, documents):
        """
        Build an index for the given list of documents.

        Parameters:
            documents (List[Document]): The list of documents to be indexed.

        Returns:
            dict: The built index.
        """
        index = {}
        for doc in documents:
            for key, value in doc.metadata.items():
                if key not in index:
                    index[key] = {}
                if value not in index[key]:
                    index[key][value] = []
                index[key][value].append(doc)
        return index

    def retrieve_metadata(self, search_dict):
        """
        Retrieve documents based on the search criteria provided in search_dict.

        Parameters:
            search_dict (dict): Dictionary specifying the search criteria.
                                It can contain "AND" or "OR" operators for
                                complex queries.

        Returns:
            List[Document]: List of documents that match the search criteria.
        """
        if "AND" in search_dict:
            return self._handle_and(search_dict["AND"])
        elif "OR" in search_dict:
            return self._handle_or(search_dict["OR"])
        else:
            return self._handle_single(search_dict)

    def _handle_and(self, search_dicts):
        results = [self.retrieve_metadata(sd) for sd in search_dicts]
        if results:
            intersection = set.intersection(
                *[set(map(self._hash_doc, r)) for r in results]
            )
            return [self._unhash_doc(h) for h in intersection]
        else:
            return []

    def _handle_or(self, search_dicts):
        results = [self.retrieve_metadata(sd) for sd in search_dicts]
        union = set.union(*[set(map(self._hash_doc, r)) for r in results])
        return [self._unhash_doc(h) for h in union]

    def _handle_single(self, search_dict):
        unions = []
        for key, query in search_dict.items():
            operator, value = query
            union = set()
            if operator == IndexerOperator.EQ:
                if key in self.index and value in self.index[key]:
                    union.update(map(self._hash_doc, self.index[key][value]))
            else:
                if key in self.index:
                    for k, v in self.index[key].items():
                        if (
                            (operator == IndexerOperator.GT and k > value)
                            or (operator == IndexerOperator.GTE and k >= value)
                            or (operator == IndexerOperator.LT and k < value)
                            or (operator == IndexerOperator.LTE and k <= value)
                        ):
                            union.update(map(self._hash_doc, v))
            if union:
                unions.append(union)

        if unions:
            intersection = set.intersection(*unions)
            return [self._unhash_doc(h) for h in intersection]
        else:
            return []

    def _hash_doc(self, doc):
        return (doc.page_content, frozenset(doc.metadata.items()))

    def _unhash_doc(self, hashed_doc):
        page_content, metadata = hashed_doc
        return Document(page_content=page_content, metadata=dict(metadata))
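
# Hedged usage sketch for DocIndexer; `docs` stands for a list of langchain Documents
# built elsewhere, and the "source" and "page" metadata keys are assumptions about how
# those documents were created, not guarantees made by this class.
#
#   indexer = DocIndexer(docs)
#   hits = indexer.retrieve_metadata(
#       {
#           "AND": [
#               {"source": (IndexerOperator.EQ, "paper.pdf")},
#               {"page": (IndexerOperator.GTE, 3)},
#           ]
#       }
#   )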

def _get_chat_history(chat_history: List[ChatTurnType]) -> str:
    buffer = ""
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
            buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Human: " + dialogue_turn[0]
            ai = "Assistant: " + dialogue_turn[1]
            buffer += "\n" + "\n".join([human, ai])
        else:
            raise ValueError(
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
    return buffer
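
# Small illustration of the formatting _get_chat_history produces for tuple turns;
# BaseMessage turns are rendered the same way via _ROLE_MAP.
#
#   _get_chat_history([("Hi there", "Hello! How can I help?")])
#   # -> "\nHuman: Hi there\nAssistant: Hello! How can I help?"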

def _get_standalone_questions_list(
    standalone_questions_str: str, original_question: str
) -> List[str]:
    pattern = r"\d+\.\s(.*?)(?=\n\d+\.|\n|$)"
    matches = [
        match.group(1) for match in re.finditer(pattern, standalone_questions_str)
    ]
    if matches:
        return matches

    match = re.search(
        r"(?i)standalone[^\n]*:[^\n](.*)", standalone_questions_str, re.DOTALL
    )
    sentence_source = match.group(1).strip() if match else standalone_questions_str
    sentences = sentence_source.split("\n")

    return [
        re.sub(
            r"^\((\d+)\)\.? ?|^\d+\.? ?\)?|^(\d+)\) ?|^[Qq]uery \d+: ?|^[Qq]uery: ?",
            "",
            sentence.strip(),
        )
        for sentence in sentences
        if sentence.strip()
    ]
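
# Illustration of the parsing fallbacks in _get_standalone_questions_list: numbered
# "1. ..." output is split by the first regex; otherwise the text after a
# "Standalone question(s):" style prefix is split line by line and cleaned.
#
#   _get_standalone_questions_list(
#       "1. What is BM25?\n2. How does Chroma store embeddings?", "original question"
#   )
#   # -> ["What is BM25?", "How does Chroma store embeddings?"]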