import logging
from enum import Enum
import tiktoken
import re
from typing import Any, Optional, Union, Collection, AbstractSet, Literal, List
from langchain.text_splitter import TextSplitter
import random
import string
from itertools import chain
import json
from lpm_kernel.configs.logging import get_train_process_logger
logger = get_train_process_logger()
class IntentType(Enum):
Emotion = "Emotion"
Knowledge = "Knowledge"
def select_language_desc(
preferred_language,
default_desc="Identify the language of the provided Hint. Your response must be in the same language.",
):
custom_desc = "You must respond in {}."
if isinstance(preferred_language, str) and "/" in preferred_language:
native, es = preferred_language.split("/")
logging.info(f"Native: {native}, ES: {es}")
return custom_desc.format(es)
    else:
        logging.warning(
            "preferred_language is not in the expected 'native/es' format; falling back to the default description."
        )
        return default_desc
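# Illustrative usage of select_language_desc (hypothetical values, not from the original module):
# select_language_desc("中文/English") splits the value on "/" and returns
# "You must respond in English.", while a value without "/" (e.g. just "English")
# falls back to the default description above.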
def cal_upperbound(
model_limit: int = 4096,
generage_limit: int = 512,
tolerance: int = 500,
raw: str = "",
model_name: str = "gpt-3.5-turbo",
) -> int:
"""
:param model_limit: Maximum token count for the underlying model call
:param tolerance: Error tolerance buffer
:param raw: system prompt and raw content
:return:
"""
if model_name is not None:
if model_name in tiktoken.model.MODEL_TO_ENCODING:
enc = tiktoken.encoding_for_model(model_name)
logging.info(f"Successfully initialized tokenizer for model: {model_name}")
else:
enc = tiktoken.get_encoding("cl100k_base")
logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base")
else:
enc = tiktoken.get_encoding("cl100k_base")
logging.info(f"No model specified, using default tokenizer: cl100k_base")
raw_token = len(enc.encode(raw))
upper_bound = model_limit - raw_token - tolerance - generage_limit
if upper_bound < 0:
logging.info(f"raw content is too long: {raw_token}")
return 0
return upper_bound
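# Worked example for cal_upperbound (hypothetical numbers, not from the original module):
# with model_limit=4096, generage_limit=512, tolerance=500 and a raw prompt that encodes
# to roughly 1000 tokens, the remaining budget is 4096 - 1000 - 500 - 512 = 2084 tokens;
# if the raw prompt alone already exceeded the limit, the function would return 0 instead.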
def equidistant_filter(chunks, separator, filtered_chunks_n=6):
    # Keep evenly spaced chunks from the beginning and middle, and always keep the last two chunks (joined into one piece)
gap = (len(chunks) - 2) / (filtered_chunks_n - 2)
indexes = [
int(gap * i)
for i in range(int(len(chunks) / gap) + 1)
if (gap * i < len(chunks) - 2)
]
filtered_chunks = [chunks[i] for i in indexes]
filtered_chunks.append(separator.join(chunks[-2:]))
return filtered_chunks
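# Illustrative behaviour of equidistant_filter (hypothetical input): with 10 chunks and
# filtered_chunks_n=6, gap = (10 - 2) / (6 - 2) = 2.0, so indexes [0, 2, 4, 6] are kept
# and the last two chunks are appended as a single separator-joined piece.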
def tab_or_space_replacement(match):
# If there is a tab character in the matched string, replace it with a single tab, otherwise replace it with a single space
return "\t" if "\t" in match.group() else " "
def text_filter(text: str) -> str:
pattern_tab_space = "[ \t]{3,}"
pattern_wordwrap = "[\n\f\r\v]{3,}"
# Replace when encountering three or more spaces or tabs
replaced_text = re.sub(pattern_tab_space, tab_or_space_replacement, text)
    # Collapse runs of three or more newline-like characters (\n, \f, \r, \v) into exactly two newlines
replaced_text = re.sub(pattern_wordwrap, "\n\n", replaced_text)
return replaced_text
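# Illustrative behaviour of text_filter (hypothetical input):
# text_filter("a    b\n\n\n\nc") collapses the run of spaces to a single space and the
# run of newlines to exactly two, yielding "a b\n\nc".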
ALLOW_SPECIAL_TOKEN = {"<|endofprompt|>", "<|endoftext|>"}
def find_sublist_indices(main_list, sublist):
indices = []
length = len(sublist)
for i in range(len(main_list) - length + 1):
if main_list[i : i + length] == sublist:
indices.append((i, i + length))
return indices
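# Illustrative behaviour of find_sublist_indices (hypothetical input):
# find_sublist_indices([1, 2, 3, 2, 3], [2, 3]) returns [(1, 3), (3, 5)], i.e. the
# half-open (start, end) index pairs of every occurrence of the sublist.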
class TokenTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at tokens."""
def __init__(
self,
encoding_name: str = "cl100k_base",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
try:
import tiktoken
except ImportError:
            raise ValueError(
                "Could not import the tiktoken python package. "
                "It is required for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tiktoken encoder instance for the configured model/encoding
if model_name is not None:
if model_name in tiktoken.model.MODEL_TO_ENCODING:
enc = tiktoken.encoding_for_model(model_name)
logging.info(f"Successfully initialized tokenizer for model: {model_name}")
else:
enc = tiktoken.get_encoding(encoding_name)
logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: {encoding_name}")
else:
enc = tiktoken.get_encoding(encoding_name)
logging.info(f"No model specified, using default tokenizer: {encoding_name}")
self._tokenizer = enc
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# Filter content with a large number of whitespace characters in the input text to increase the proportion of effective content within chunks
text = text_filter(text)
splits = []
input_ids = self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
start_idx = 0
while start_idx < len(input_ids):
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
s = self._tokenizer.decode(chunk_ids).strip()
if s:
s = self._cut_meaningless_head_tail(s)
if s:
splits.append(s)
start_idx += self._chunk_size - self._chunk_overlap
logging.debug("finished split_text(): %s splits", len(splits))
return splits
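    # Sliding-window arithmetic used by split_text above (hypothetical numbers): with
    # chunk_size=200 and chunk_overlap=50 the token windows are [0, 200), [150, 350),
    # [300, 500), ... i.e. the window start advances by chunk_size - chunk_overlap = 150 tokens.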
def _cut_meaningless_head_tail(self, text: str) -> str:
# Only split when there are multiple newlines, as parsing of PDF/Word often contains false newlines
sentences = re.split(r"\. |! |\? |。|!|?|\n+ *\n+", text)
if len(sentences) < 2:
return text
head = sentences[0]
body = ". ".join(sentences[1:-1])
tail = sentences[-1]
        head_len = len(
            self._tokenizer.encode(
                head,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )
        )
body_len = len(
self._tokenizer.encode(
body,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
)
tail_len = len(
self._tokenizer.encode(
tail,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
)
parts = []
        # Use length to roughly estimate the impact of discarding the head/tail; if the impact is small, drop it
        # Rough heuristic: keep a fragment only if it is at least ~20 tokens or ~30 characters long
if head_len >= 20 or len(head) >= 30:
parts.append(head)
if body_len > 0:
parts.append(body)
if tail_len >= 20 or len(tail) >= 30:
parts.append(tail)
res = "\n".join(parts)
logger.info(
"_cut_meaningless_tail() removes redundant sentence tails from chunks, before cut: %s characters, after cut: %s characters",
len(text),
len(res),
)
return res
def chunk_filter(
chunks, filter, filtered_chunks_n=6, separator="\n", spacer="\n……\n……\n……\n"
):
if len(chunks) <= filtered_chunks_n:
return separator.join(chunks)
return spacer.join(filter(chunks, separator, filtered_chunks_n))
def get_safe_content_turncate(content, model_name="gpt-3.5-turbo", max_tokens=3300):
if model_name is not None:
if model_name in tiktoken.model.MODEL_TO_ENCODING:
enc = tiktoken.encoding_for_model(model_name)
logging.info(f"Successfully initialized tokenizer for model: {model_name}")
else:
enc = tiktoken.get_encoding("cl100k_base")
logging.warning(f"Model '{model_name}' doesn't have a corresponding tokenizer, falling back to default: cl100k_base")
else:
enc = tiktoken.get_encoding("cl100k_base")
logging.info(f"No model specified, using default tokenizer: cl100k_base")
    tokens = enc.encode(content)
    logging.debug(
        "get_safe_content_turncate(): model maximum input length is %s, current input length is %s",
        max_tokens,
        len(tokens),
    )
    if len(tokens) > max_tokens:
        logging.warning(
            "get_safe_content_turncate(): input length %s exceeds the maximum %s, truncating",
            len(tokens),
            max_tokens,
        )
        content = enc.decode(tokens[:max_tokens])
    return content
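# Illustrative usage of get_safe_content_turncate (hypothetical values): if `content`
# encodes to 5000 tokens with max_tokens=3300, only the first 3300 tokens are kept and
# decoded back to text; shorter content is returned unchanged.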
class DataType(Enum):
DOCUMENT = "DOCUMENT"
WEBSITE = "WEBSITE"
IMAGE = "IMAGE"
TABLE = "TABLE"
AUDIO = "AUDIO"
TEXT = "TEXT"
@staticmethod
def extra_values_map():
return {
"SHORT_AUDIO": "AUDIO",
}
@classmethod
def _missing_(cls, value):
# Try to find the corresponding primary key value from the extra value mapping
extra_map = cls.extra_values_map()
if value in extra_map:
value = extra_map[value]
        member = cls.__members__.get(value)
        if member is not None:
            return member
        # If not found, log an error and return DOCUMENT by default
        logging.error("DataType._missing_(): Could not find corresponding DataType enum value: %s", value)
        return cls.DOCUMENT
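# Illustrative behaviour of DataType (hypothetical values): DataType("SHORT_AUDIO") resolves
# to DataType.AUDIO through extra_values_map(), while a completely unknown value such as
# DataType("VIDEO") logs an error and falls back to DataType.DOCUMENT.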
def get_urls(string):
url_arr = []
if not string:
return url_arr
pattern = re.compile(
r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;\u4e00-\u9fa5]+[-A-Za-z0-9+&@#/%=~_|]"
)
matcher = pattern.finditer(string)
for match in matcher:
url_arr.append(match.group())
sorted_url_arr = sorted(set(url_arr), key=len, reverse=True)
return sorted_url_arr
def get_random_string(s_length: int) -> str:
# Generate a random string
letters = string.ascii_letters + string.digits
return "".join(random.choice(letters) for i in range(s_length))
def get_random_strings(n: int, s_length: int) -> List[str]:
unique_strings = set()
while len(unique_strings) < n:
unique_strings.add(get_random_string(s_length))
return list(unique_strings)
def encode_urls(text, random_string_len: int = 16):
urls = get_urls(text)
random_strings = get_random_strings(len(urls), random_string_len)
url2string_dict = dict(zip(urls, random_strings))
string2url_dict = dict(zip(random_strings, urls))
for url, random_string in url2string_dict.items():
text = text.replace(url, random_string)
return text, string2url_dict
def decode_urls(text, string2url_dict):
for random_string, url in string2url_dict.items():
text = text.replace(random_string, url)
return text
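# Illustrative round trip for encode_urls/decode_urls (hypothetical input): encode_urls
# replaces each URL found by get_urls with a unique 16-character random token so that
# punctuation inside URLs does not disturb sentence splitting, e.g.
#   masked, mapping = encode_urls("see https://example.com/a?b=1 for details")
#   assert decode_urls(masked, mapping) == "see https://example.com/a?b=1 for details"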
class TokenParagraphSplitter(TextSplitter):
"""For business data characteristics, perform some additional processing. This includes:
1. Complete fragments as independent chunks help improve information focus in each chunk. Complete fragments are mainly determined by period+newline.
2. When complete fragments are too long, split them into sentences and combine sentences into chunks that meet window size limits
3. If a sentence is too long, split it directly by token granularity
"""
line_break_characters = ["\n", "\f", "\r", "\v"]
whitespace_characters = [" ", "\t"]
sentence_terminators = [
".",
"!",
"?",
"。",
"!",
"?",
"……",
"...",
] + line_break_characters
paired_punctuation = [
("(", ")"),
("[", "]"),
("{", "}"),
("<", ">"),
("“", "”"),
("‘", "’"),
("《", "》"),
("【", "】"),
]
intra_sentence_delimiters = [",", ",", ";", ";"] + whitespace_characters
def __init__(
self,
encoding_name: str = "cl100k_base",
allowed_special: Union[Literal["all"], AbstractSet[str]] = ALLOW_SPECIAL_TOKEN,
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
try:
import tiktoken
except ImportError:
            raise ValueError(
                "Could not import the tiktoken python package. "
                "It is required for TokenParagraphSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tiktoken encoder instance for the configured encoding
self._tokenizer = tiktoken.get_encoding(encoding_name)
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
chunks = []
# Clean up abnormal whitespace characters in the text, such as replacing 3 or more consecutive \n with \n\n
text = text_filter(text)
# Replace URLs in the text to avoid symbols like ./?/ in URLs interfering with sentence splitting
text, string2url_dict = encode_urls(text)
url_strings = list(string2url_dict.keys())
# Split by paragraphs according to rules
paragraphs = self._split_to_paragraphs(
text, min_paragraph_length=self._chunk_size // 2
)
for i, paragraph in enumerate(paragraphs):
splits = self._split_to_chunks(paragraph, url_strings)
logging.debug(
"paragraph %s/%s %s characters: %s",
i + 1,
len(paragraphs),
len(paragraph),
paragraph,
)
logging.debug(
"paragraph %s/%s split into %s chunks: %s",
i + 1,
len(paragraphs),
len(splits),
splits,
)
chunks.extend(splits)
chunks = [decode_urls(chunk, string2url_dict) for chunk in chunks]
return chunks
def _split_to_chunks(self, text: str, url_strings: List[str] = []) -> List[str]:
sentences = self._split_to_sentences(text, url_strings)
chunks = self._merge_sentences_into_chunks(
sentences, min_chunk_size=self._chunk_size // 2
)
return chunks
def _split_to_paragraphs(
self, text: str, min_paragraph_length: int = 0
) -> List[str]:
"""Currently split the original document into paragraphs directly based on the \n[any space]\n rule."""
line_break_characters = "".join(self.line_break_characters)
whitespace_characters = "".join(self.whitespace_characters)
paragraphs = re.split(
f"([{line_break_characters}]+[{whitespace_characters}]*[{line_break_characters}])+",
text,
)
if len(paragraphs) % 2 == 1:
paragraphs = [""] + paragraphs
paragraphs = [
(paragraphs[i], paragraphs[i + 1])
for i in range(0, len(paragraphs), 2)
if (paragraphs[i] + paragraphs[i + 1]).strip()
]
if not paragraphs:
return []
new_paragraphs = []
cur_paragraph, cur_paragraph_len = "", 0
# merge short or broken paragraphs
for sep, paragraph in paragraphs:
if cur_paragraph_len >= min_paragraph_length and any(
cur_paragraph.endswith(sym) for sym in self.sentence_terminators
):
new_paragraphs.append(cur_paragraph.strip())
cur_paragraph, cur_paragraph_len = "", 0
cur_paragraph_len += len(self._tokenizer.encode(sep + paragraph))
cur_paragraph += sep + paragraph
if cur_paragraph:
new_paragraphs.append(cur_paragraph.strip())
return new_paragraphs
def _split_to_sentences(self, text: str, url_strings: List[str] = []) -> List[str]:
# Use capture groups to preserve sentence separators
pattern = (
f"({'|'.join(re.escape(symbol) for symbol in self.sentence_terminators)})+"
)
parts = re.split(pattern, text)
sentences = []
# Merge by skipping steps to ensure punctuation is added to the end of the corresponding sentence
if len(parts) % 2 == 1:
parts.append("")
sentences = ["".join(parts[i : i + 2]) for i in range(0, len(parts), 2)]
sentences = [s for s in sentences if s.strip()]
if not sentences:
return []
# Fix fragmented sentences, mainly for special cases such as numeric indices, floating-point numbers, etc., which may be separated
sentences = self.recombine_broken_sentences(sentences)
# Split sentences that are too long; in the short term, split directly by character length; future optimizations could consider splitting by punctuation within sentences
sentences_list = [
self._force_split_to_chunks(s, url_strings) for s in sentences
]
sentences = list(chain.from_iterable(sentences_list))
return sentences
def recombine_broken_sentences(self, sentences: List[str]) -> List[str]:
"""Fix fragmented sentences, mainly for special cases such as numeric indices, floating-point numbers, etc., which may be separated。"""
if len(sentences) < 2:
return sentences
open_symbols_dict = {
open_sym: close_sym for open_sym, close_sym in self.paired_punctuation
}
close_symbols_dict = {
close_sym: open_sym for open_sym, close_sym in self.paired_punctuation
}
new_sentences = []
cur_sentences = ""
unmatched_symbol = []
for sent in sentences:
# If the current sentence is not empty, doesn't meet predefined merge conditions, and has no pending matching punctuation ([, (, {, etc.), then consider the sentence complete
if cur_sentences.strip() and not (
self.check_merge(cur_sentences, sent) or unmatched_symbol
):
new_sentences.append(cur_sentences)
cur_sentences = ""
for c in sent:
if c in open_symbols_dict:
unmatched_symbol.append(c)
elif c in close_symbols_dict:
if (
unmatched_symbol
and unmatched_symbol[-1] == close_symbols_dict[c]
):
unmatched_symbol.pop()
# By default, the current sentence ends when a newline-like character appears
if c in self.line_break_characters:
unmatched_symbol = []
if cur_sentences.strip():
new_sentences.append(cur_sentences)
cur_sentences = ""
cur_sentences += c
if cur_sentences:
new_sentences.append(cur_sentences)
return new_sentences
def check_merge(self, pre_sen, cur_sen):
if len(pre_sen) > 1 and len(cur_sen) > 0:
# If it's a decimal point in the middle of a floating-point number
if pre_sen[-1] == "." and pre_sen[-2].isdigit() and cur_sen[0].isdigit():
return True
# If it's a numeric index at the beginning of a sentence, such as 1. *****\n2. *****
if (
pre_sen[-1] == "."
and pre_sen[-2].isdigit()
and cur_sen[0] not in self.line_break_characters
):
return True
# In markdown format, ! followed by [ may be an image link
if (
pre_sen[-1] == "!"
and pre_sen[-2] in self.line_break_characters
and cur_sen[0] == "["
):
return True
return False
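    # Illustrative check_merge cases (hypothetical fragments): "3." followed by "14159" is
    # merged because the period sits inside a floating-point number, and "1." followed by
    # " First item" is merged because it looks like a numbered list index; an ordinary
    # sentence ending in "." followed by a new line-break fragment is not merged.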
def _merge_sentences_into_chunks(
self, sentences: List[str], min_chunk_size: int = 200
) -> List[str]:
"""Assemble into chunks according to chunk_size and overlap. Note that external guarantees ensure that the length of a single sentence does not exceed chunk_size"""
if not sentences:
return []
n_tokens = [
len(
self._tokenizer.encode(
sentence,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
)
for sentence in sentences
]
chunks = []
start_idx = 0
end_idx = start_idx + 1
cur_token_num = n_tokens[start_idx]
while start_idx < len(n_tokens):
            # The window end has reached the last sentence; emit the final chunk
if end_idx >= len(n_tokens):
chunk = "".join(sentences[start_idx:end_idx])
logging.debug(
"sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
start_idx,
end_idx,
sum(n_tokens[start_idx:end_idx]),
cur_token_num,
)
chunks.append(chunk)
break
else:
                # Adding the next sentence stays within chunk_size; keep extending the window
if cur_token_num + n_tokens[end_idx] <= self._chunk_size:
cur_token_num += n_tokens[end_idx]
end_idx += 1
                # Adding the next sentence would exceed chunk_size; emit the current chunk and move on to the next one
else:
chunk = "".join(sentences[start_idx:end_idx])
logging.debug(
"sentences[%s:%s] merged into chunk, current num_tokens: %s(%s)",
start_idx,
end_idx,
sum(n_tokens[start_idx:end_idx]),
cur_token_num,
)
chunks.append(chunk)
# Next chunk: idx moves at least one position forward, start_idx allows overlap
end_idx = end_idx + 1
# Find a new starting point for start_idx that doesn't exceed the overlap
new_start_idx = end_idx - 1
overlap = 0
new_cur_token_num = n_tokens[new_start_idx]
while new_start_idx > start_idx + 1:
if (
overlap + n_tokens[new_start_idx - 1] >= self._chunk_overlap
or new_cur_token_num >= self._chunk_size
):
break
new_start_idx -= 1
overlap += n_tokens[new_start_idx]
new_cur_token_num += n_tokens[new_start_idx]
start_idx = new_start_idx
cur_token_num = new_cur_token_num
if len(chunks) > 1 and len(chunks[-1]) < min_chunk_size:
logging.warning(
"The last chunk length %s is less than %s, merge with the previous chunk",
len(chunks[-1]),
min_chunk_size,
)
last_chunk = chunks.pop()
chunks[-1] += last_chunk
chunks = [chunk for chunk in chunks if chunk.strip()]
return chunks
def _force_split_to_chunks(
self, text: str, url_strings: List[str] = []
) -> List[str]:
        """If a single sentence is too long it has to be split forcibly: split at punctuation within the sentence, trying to keep links and other data that require complete information intact."""
        # TODO: further refine the forced-splitting heuristics in the future
splits = []
input_ids = self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
if len(input_ids) < self._chunk_size:
return [text]
if text[-1] not in self.sentence_terminators + self.intra_sentence_delimiters:
text += self.sentence_terminators[0]
cur_sentence, cur_sentence_len = "", 0
sub_sentence = ""
for c in text:
sub_sentence += c
if c in self.intra_sentence_delimiters + self.sentence_terminators:
sub_sentence_len = len(self._tokenizer.encode(sub_sentence))
if (
cur_sentence_len + sub_sentence_len
> self._chunk_size - self._chunk_overlap
):
if cur_sentence:
splits.append(cur_sentence)
cur_sentence, cur_sentence_len = sub_sentence, sub_sentence_len
else:
# This indicates that sub_sentence is too long, at this point directly follow the forced splitting logic based on tokens
_splits = self.safe_split(sub_sentence, url_strings)
splits.extend(_splits[:-1])
cur_sentence, cur_sentence_len = _splits[-1], len(_splits[-1])
else:
cur_sentence += sub_sentence
cur_sentence_len += sub_sentence_len
sub_sentence = ""
if cur_sentence:
splits.append(cur_sentence)
return splits
def safe_split(self, sub_sentence: str, url_strings: List[str] = []) -> List[str]:
sub_sentence_tokens = self._tokenizer.encode(sub_sentence)
# Find the position intervals of all strings in url_strings
url_string_intervals = []
for url_string in url_strings:
encoded_url_string = self._tokenizer.encode(url_string)
# Use find_sublist_indices to find all position intervals
url_string_intervals.extend(
find_sublist_indices(sub_sentence_tokens, encoded_url_string)
)
_splits = []
i = 0
while i < len(sub_sentence_tokens):
if i + self._chunk_size >= len(sub_sentence_tokens):
slice_end = len(sub_sentence_tokens)
else:
slice_end = i + self._chunk_size - self._chunk_overlap
# Determine if the split interval overlaps with any important string intervals
for s_begin, s_end in url_string_intervals:
if i < s_end <= slice_end or i < s_begin < slice_end:
slice_end = max(slice_end, s_end)
# Split and record the current chunk
_splits.append(self._tokenizer.decode(sub_sentence_tokens[i:slice_end]))
# Move to the starting point of the next chunk
i = slice_end
return _splits
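# Minimal usage sketch for TokenParagraphSplitter (hypothetical parameters; chunk_size and
# chunk_overlap are inherited from langchain's TextSplitter base class):
#   splitter = TokenParagraphSplitter(chunk_size=512, chunk_overlap=64)
#   chunks = splitter.split_text(long_document_text)
# Each returned chunk is built from whole paragraphs/sentences where possible and only
# falls back to token-level force splitting for oversized sentences.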
def get_summarize_title_keywords(responses):
    # Clean LLM-generated content to obtain titles, summaries, and keywords
pattern = re.compile(r"\{.*(\}|\]|\,)", re.DOTALL)
gen_texts = [each.choices[0].message.content for each in responses]
logging.info("gen_texts: %s", gen_texts)
results = []
for res in gen_texts:
try:
# Match against the pattern
matches = list(pattern.finditer(res))
if not matches:
results.append(("", "", []))
else:
answer = matches[0].group(0)
content = answer.strip().strip(",")
content += "]" * (content.count("[") - content.count("]"))
content += "}" * (content.count("{") - content.count("}"))
                d = json.loads(content)
results.append(
(d.get("title", ""), d.get("summary", ""), d.get("keywords", []))
)
except json.JSONDecodeError:
logging.warning("JSON parsing failed, returning empty list")
results.append(("", "", []))
return results
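if __name__ == "__main__":
    # Lightweight smoke test of the splitter utilities (illustrative only, not part of the
    # original module). The fake response object below merely mimics the
    # .choices[0].message.content shape that get_summarize_title_keywords expects from an
    # OpenAI-style client; the sample text and parameters are hypothetical.
    from types import SimpleNamespace

    sample = "First paragraph.\n\nSecond paragraph with a link https://example.com/doc."
    splitter = TokenParagraphSplitter(chunk_size=64, chunk_overlap=8)
    print(splitter.split_text(sample))

    fake_response = SimpleNamespace(
        choices=[
            SimpleNamespace(
                message=SimpleNamespace(
                    content='{"title": "Demo", "summary": "A tiny example.", "keywords": ["demo"]}'
                )
            )
        ]
    )
    print(get_summarize_title_keywords([fake_response]))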