Spaces:
Sleeping
Sleeping
| import copy | |
| import re | |
| from abc import ABC, abstractmethod | |
| from typing import Callable, Iterable, List, Literal, Optional, Union | |
| from graphgen.bases.datatypes import Chunk | |
| from graphgen.utils import logger | |
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.

    Subclasses must implement :meth:`split_text`. The remaining methods build
    ``Chunk`` objects from the split results and merge small pieces into
    chunks bounded by ``chunk_size``.
    """

    def __init__(
        self,
        chunk_size: int = 1024,
        chunk_overlap: int = 100,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ):
        """
        Store the splitter configuration.

        :param chunk_size: Maximum measured length a merged chunk should have.
        :param chunk_overlap: Measured length of trailing pieces that may be
            carried over from one merged chunk into the next.
        :param length_function: Callable used to measure the length of a text.
        :param keep_separator: Stored on the instance; not read by this base
            class itself.
        :param add_start_index: If true, ``create_chunks`` records a
            ``start_index`` entry in each chunk's metadata.
        :param strip_whitespace: If true, ``_join_chunks`` strips leading and
            trailing whitespace from each joined chunk.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.length_function = length_function
        self.keep_separator = keep_separator
        self.add_start_index = add_start_index
        self.strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        Declared abstract so that a subclass which forgets to override it
        fails at instantiation instead of silently returning ``None``.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List["Chunk"]:
        """
        Create chunks from a list of texts.

        :param texts: Texts to split with :meth:`split_text`.
        :param metadatas: Optional metadata dicts, indexed in parallel with
            ``texts``; an empty dict is used for every text when omitted.
        :return: One ``Chunk`` per piece returned by :meth:`split_text`, in order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep copy so chunks never share (or mutate) a metadata dict.
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    # Start searching just before the end of the previous
                    # chunk, allowing for the configured overlap.
                    offset = index + previous_chunk_len - self.chunk_overlap
                    # str.find yields -1 when the chunk does not occur
                    # verbatim at or after the search position.
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_chunk = Chunk(content=chunk, metadata=metadata)
                chunks.append(new_chunk)
        return chunks

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        """
        Join pieces with ``separator`` into a single chunk text.

        :param chunks: Pieces to join.
        :param separator: String placed between consecutive pieces.
        :return: The joined text (stripped when ``strip_whitespace`` is set),
            or ``None`` if the result is an empty string.
        """
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """
        Merge small splits into chunks of at most ``chunk_size`` where possible.

        Lengths are measured with ``length_function``. When adding the next
        split would exceed ``chunk_size``, the pieces collected so far are
        emitted as one chunk, and leading pieces are then dropped until what
        remains is no longer than ``chunk_overlap`` and leaves room for the
        next split.

        :param splits: Pieces to merge, in order.
        :param separator: String used to join pieces inside a chunk.
        :return: The merged chunk texts.
        """
        # We now want to combine these smaller pieces into medium size chunks to send to the LLM.
        separator_len = self.length_function(separator)

        chunks = []
        current_chunk: List[str] = []
        total = 0
        for d in splits:
            _len = self.length_function(d)
            if (
                total + _len + (separator_len if len(current_chunk) > 0 else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if len(current_chunk) > 0:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self.chunk_overlap or (
                        total + _len + (separator_len if len(current_chunk) > 0 else 0)
                        > self.chunk_size
                        and total > 0
                    ):
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk = current_chunk[1:]
            current_chunk.append(d)
            # A separator is only counted once there is a piece before this one.
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        # Flush whatever is left after the last split.
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks
| def _split_text_with_regex( | |
| text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] | |
| ) -> List[str]: | |
| # Now that we have the separator, split the text | |
| if separator: | |
| if keep_separator: | |
| # The parentheses in the pattern keep the delimiters in the result. | |
| _splits = re.split(f"({separator})", text) | |
| splits = ( | |
| ( | |
| [ | |
| _splits[i] + _splits[i + 1] | |
| for i in range(0, len(_splits) - 1, 2) | |
| ] | |
| ) | |
| if keep_separator == "end" | |
| else ( | |
| [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] | |
| ) | |
| ) | |
| if len(_splits) % 2 == 0: | |
| splits += _splits[-1:] | |
| splits = ( | |
| (splits + [_splits[-1]]) | |
| if keep_separator == "end" | |
| else ([_splits[0]] + splits) | |
| ) | |
| else: | |
| splits = re.split(separator, text) | |
| else: | |
| splits = list(text) | |
| return [s for s in splits if s != ""] | |