# GraphGen — graphgen/bases/base_splitter.py
# Auto-synced from demo by github-actions[bot] at Thu Oct 23 12:37:24 UTC 2025 (commit 8c66169).
import copy
import re
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List, Literal, Optional, Union
from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.

    Subclasses implement :meth:`split_text`. This base class provides:
    chunk assembly into ``Chunk`` objects (:meth:`create_chunks`),
    separator-aware joining (:meth:`_join_chunks`), greedy merging of
    small splits up to ``chunk_size`` with ``chunk_overlap`` carry-over
    (:meth:`_merge_splits`), and regex-based splitting that can keep the
    separators (:meth:`_split_text_with_regex`).
    """

    def __init__(
        self,
        chunk_size: int = 1024,
        chunk_overlap: int = 100,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
        strip_whitespace: bool = True,
    ):
        # Target maximum size of a merged chunk, measured by `length_function`.
        self.chunk_size = chunk_size
        # Amount of trailing content (same units) carried into the next chunk.
        self.chunk_overlap = chunk_overlap
        # Measures text length — `len` by default; could be a token counter.
        self.length_function = length_function
        # Whether separators are retained in the split pieces
        # (passed as bool or "start"/"end" to `_split_text_with_regex`).
        self.keep_separator = keep_separator
        # If True, `create_chunks` records each chunk's character offset
        # in the source text under metadata["start_index"].
        self.add_start_index = add_start_index
        # If True, merged chunks have leading/trailing whitespace stripped.
        self.strip_whitespace = strip_whitespace

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Chunk]:
        """Create chunks from a list of texts.

        :param texts: Input texts; each is split via :meth:`split_text`.
        :param metadatas: Optional per-text metadata dicts (parallel to
            ``texts``); defaults to empty metadata for every text.
        :return: One ``Chunk`` per produced piece, in order.
        """
        _metadatas = metadatas or [{}] * len(texts)
        chunks: List[Chunk] = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep-copy so chunks never share (or mutate) a caller's dict —
                # important because the default above repeats one dict object.
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    # Search from just before the previous chunk's end so
                    # overlapping chunks resolve to their own start offsets.
                    # NOTE(review): str.find yields -1 if the chunk text no
                    # longer occurs verbatim (e.g. after whitespace stripping).
                    offset = index + previous_chunk_len - self.chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                    previous_chunk_len = len(chunk)
                new_chunk = Chunk(content=chunk, metadata=metadata)
                chunks.append(new_chunk)
        return chunks

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        """Join pieces with ``separator``; return None if the result is empty
        (after optional whitespace stripping) so callers can drop it."""
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size chunks to send to the LLM.
        separator_len = self.length_function(separator)

        chunks: List[str] = []
        # Pieces accumulated for the chunk currently being built, plus the
        # running length of that chunk (including separators between pieces).
        current_chunk: List[str] = []
        total = 0
        for d in splits:
            _len = self.length_function(d)
            # Would adding this piece (plus a joining separator, if the
            # chunk is non-empty) overflow the target chunk size?
            if (
                total + _len + (separator_len if len(current_chunk) > 0 else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    # A single piece can exceed chunk_size; we emit it anyway
                    # and just warn — splitting further is the subclass's job.
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if len(current_chunk) > 0:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    # (safe: when current_chunk is empty, total == 0, so both
                    # conditions are false and current_chunk[0] is never read)
                    while total > self.chunk_overlap or (
                        total + _len + (separator_len if len(current_chunk) > 0 else 0)
                        > self.chunk_size
                        and total > 0
                    ):
                        # Drop the oldest piece, plus the separator it
                        # contributed if it was not the only piece.
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk = current_chunk[1:]
            current_chunk.append(d)
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        # Flush whatever remains as the final chunk.
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        """Split ``text`` on the regex ``separator``.

        :param text: Text to split.
        :param separator: Regex pattern; if empty/falsy, the text is split
            into individual characters.
        :param keep_separator: False discards separators; "end" attaches each
            separator to the preceding piece; any other truthy value
            (True or "start") attaches it to the following piece.
        :return: Non-empty pieces, in order.
        """
        # Now that we have the separator, split the text
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result,
                # so _splits alternates [text, sep, text, sep, ..., text].
                _splits = re.split(f"({separator})", text)
                splits = (
                    # "end": pair each text with the separator that follows it.
                    (
                        [
                            _splits[i] + _splits[i + 1]
                            for i in range(0, len(_splits) - 1, 2)
                        ]
                    )
                    if keep_separator == "end"
                    # "start"/True: pair each separator with the text after it.
                    else (
                        [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                    )
                )
                # An even length means a dangling element was left unpaired
                # (possible when the separator pattern itself contains groups).
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                # Re-attach the boundary piece the pairing loop skipped:
                # the trailing text for "end", the leading text otherwise.
                splits = (
                    (splits + [_splits[-1]])
                    if keep_separator == "end"
                    else ([_splits[0]] + splits)
                )
            else:
                splits = re.split(separator, text)
        else:
            # No separator: fall back to character-level splitting.
            splits = list(text)
        # Filter out empty strings produced by adjacent/edge separators.
        return [s for s in splits if s != ""]