Spaces:
Sleeping
Sleeping
File size: 5,422 Bytes
43d27f2 8c66169 43d27f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import copy
import re
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Literal, Optional, Union
from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.

    Subclasses implement :meth:`split_text`; this base class supplies the
    shared helpers for splitting on a regex separator, merging small pieces
    into size-bounded chunks, and wrapping the results in ``Chunk`` objects.
    """

    def __init__(
        self,
        chunk_size: int = 1024,
        chunk_overlap: int = 100,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ):
        """
        Store the splitter configuration.

        :param chunk_size: Upper bound on a merged chunk's size, as measured
            by ``length_function``.
        :param chunk_overlap: Amount of trailing content (same unit) that may
            be carried over from one merged chunk into the next.
        :param length_function: Callable used to measure pieces and separators.
        :param keep_separator: Stored as-is; not read by the methods of this
            base class (the regex helper takes its own argument).
        :param add_start_index: When true, ``create_chunks`` records each
            chunk's position in its source text under ``"start_index"``.
        :param strip_whitespace: When true, joined chunks are stripped of
            leading and trailing whitespace.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.length_function = length_function
        self.keep_separator = keep_separator
        self.add_start_index = add_start_index
        self.strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    # The return annotation is a forward reference so that it is not
    # evaluated when the class body is executed.
    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> "List[Chunk]":
        """
        Create chunks from a list of texts.

        Each text is passed through :meth:`split_text`, and every resulting
        piece is wrapped in a ``Chunk`` together with a deep copy of that
        text's metadata.

        :param texts: The texts to split.
        :param metadatas: Optional metadata dicts, indexed in parallel with
            ``texts``. When omitted or empty, every chunk starts from ``{}``.
        :return: The chunks of all texts, in input order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep copy so chunks never share (or mutate) a metadata dict.
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    # Resume the search just before the end of the previous
                    # chunk, backing up by the overlap so that an overlapping
                    # chunk is still found.
                    offset = index + previous_chunk_len - self.chunk_overlap
                    # str.find yields -1 when the chunk does not occur
                    # verbatim in the text; that value is recorded unchanged.
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                chunks.append(Chunk(content=chunk, metadata=metadata))
        return chunks

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        """
        Join pieces with ``separator`` into a single chunk.

        :param chunks: The pieces to join.
        :param separator: The string placed between consecutive pieces.
        :return: The joined text (stripped when ``strip_whitespace`` is set),
            or ``None`` if the result is empty.
        """
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """
        Combine small pieces into chunks of at most ``chunk_size``.

        Pieces are accumulated until adding the next one would exceed
        ``chunk_size``; the accumulated pieces are then emitted as one chunk
        and leading pieces are dropped until what remains fits within
        ``chunk_overlap`` and leaves room for the next piece. A single piece
        longer than ``chunk_size`` is emitted as-is, with a warning.

        :param splits: The pieces to merge, in order.
        :param separator: The string used to join pieces inside a chunk.
        :return: The merged chunks.
        """
        separator_len = self.length_function(separator)

        chunks: List[str] = []
        current_chunk: List[str] = []
        total = 0  # measured size of current_chunk, separators included
        for d in splits:
            _len = self.length_function(d)
            if (
                total + _len + (separator_len if current_chunk else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if current_chunk:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    # The emptiness guard prevents indexing an empty list,
                    # which would otherwise happen for a negative overlap.
                    while current_chunk and (
                        total > self.chunk_overlap
                        or (
                            total + _len + (separator_len if current_chunk else 0)
                            > self.chunk_size
                            and total > 0
                        )
                    ):
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk.pop(0)
            current_chunk.append(d)
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        """
        Split ``text`` on the regular expression ``separator``.

        :param text: The text to split.
        :param separator: A regular expression. When empty, the text is split
            into individual characters.
        :param keep_separator: Falsy to drop the matched separators;
            ``"end"`` to attach each separator to the end of the piece before
            it; any other truthy value (``True`` / ``"start"``) to attach it
            to the start of the piece after it.
        :return: The non-empty pieces, in order.
        """
        if not separator:
            return list(text)

        # Walk the matches by hand instead of wrapping the pattern in a
        # capture group for re.split: re.split also returns the text of every
        # group inside the separator (None for unmatched ones), which corrupts
        # the piece/separator pairing, and the extra wrapper group renumbers
        # backreferences and pushes inline flags away from the pattern start.
        pieces: List[str] = []
        delimiters: List[str] = []
        cursor = 0
        for match in re.finditer(separator, text):
            pieces.append(text[cursor : match.start()])
            delimiters.append(match.group(0))
            cursor = match.end()
        pieces.append(text[cursor:])  # always one more piece than delimiters

        if not keep_separator:
            splits = pieces
        elif keep_separator == "end":
            splits = [piece + delim for piece, delim in zip(pieces, delimiters)]
            splits.append(pieces[-1])
        else:
            splits = [pieces[0]]
            splits.extend(
                delim + piece for delim, piece in zip(delimiters, pieces[1:])
            )
        return [s for s in splits if s != ""]
|