from functools import lru_cache
from typing import Optional, Union

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import (
    ChineseRecursiveTextSplitter,
    RecursiveCharacterSplitter,
    Tokenizer,
)
from graphgen.utils import compute_content_hash, detect_main_language

_MAPPING = {
    "en": RecursiveCharacterSplitter,
    "zh": ChineseRecursiveTextSplitter,
}

SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter]


@lru_cache(maxsize=None)
def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT:
    # Cache one splitter instance per (language, kwargs) combination; kwargs
    # arrive as a frozenset of items so they are hashable for lru_cache.
    cls = _MAPPING[language]
    kwargs = dict(frozen_kwargs)
    return cls(**kwargs)
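
# Because _get_splitter is wrapped in lru_cache, repeated calls with the same
# (language, frozen_kwargs) pair reuse a single splitter instance. A quick
# sanity check (illustrative only, not part of the module):
#
#   split_chunks("hello world")
#   split_chunks("hello again")
#   print(_get_splitter.cache_info())  # hits >= 1 once the cached splitter is reused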


def split_chunks(text: str, language: str = "en", **kwargs) -> list:
    """Split `text` with the splitter registered for `language` ("en" or "zh")."""
    if language not in _MAPPING:
        raise ValueError(
            f"Unsupported language: {language}. "
            f"Supported languages are: {list(_MAPPING.keys())}"
        )
    # Convert list values (e.g. custom separator lists) to tuples so the kwargs
    # are hashable and the splitter instance can be cached.
    frozen_kwargs = frozenset(
        (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items()
    )
    splitter = _get_splitter(language, frozen_kwargs)
    return splitter.split_text(text)
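
# Illustrative usage (a minimal sketch; chunk_size / chunk_overlap are assumed
# keyword arguments and depend on the underlying splitter's constructor):
#
#   chunks = split_chunks(
#       "GraphGen builds knowledge graphs from raw text corpora.",
#       language="en",
#       chunk_size=512,
#       chunk_overlap=64,
#   )
#   # -> a list of text chunks produced by the cached RecursiveCharacterSplitter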


async def chunk_documents(
    new_docs: dict,
    tokenizer_instance: Optional[Tokenizer] = None,
    progress_bar=None,
    **kwargs,
) -> dict:
    """Split each incoming document into chunks keyed by a hash of the chunk content."""
    inserting_chunks = {}
    cur_index = 1
    doc_number = len(new_docs)
    async for doc_key, doc in tqdm_async(
        new_docs.items(), desc="[1/4] Chunking documents", unit="doc"
    ):
        doc_type = doc.get("type")
        if doc_type == "text":
            doc_language = detect_main_language(doc["content"])

            text_chunks = split_chunks(
                doc["content"],
                language=doc_language,
                **kwargs,
            )

            # Chunk length is measured in tokens when a tokenizer is provided,
            # otherwise in characters.
            chunks = {
                compute_content_hash(txt, prefix="chunk-"): {
                    "content": txt,
                    "type": "text",
                    "_full_docs_id": doc_key,
                    "length": len(tokenizer_instance.encode(txt))
                    if tokenizer_instance
                    else len(txt),
                    "language": doc_language,
                }
                for txt in text_chunks
            }
        else:
            # Non-text documents are passed through unchanged, re-keyed by type.
            chunks = {doc_key.replace("doc-", f"{doc_type}-"): {**doc}}

        inserting_chunks.update(chunks)

        if progress_bar is not None:
            progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
            cur_index += 1

    return inserting_chunks
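

# Illustrative usage (a minimal sketch, not part of the module). It assumes
# documents shaped like {"doc-<hash>": {"type": "text", "content": ...}}, which
# matches how chunk_documents reads them; the concrete key names below are an
# assumption for demonstration only.
#
#   import asyncio
#
#   async def _demo():
#       docs = {
#           "doc-demo": {
#               "type": "text",
#               "content": "GraphGen splits documents into language-aware chunks.",
#           }
#       }
#       chunks = await chunk_documents(docs)
#       for chunk_id, chunk in chunks.items():
#           print(chunk_id, chunk["language"], chunk["length"])
#
#   asyncio.run(_demo())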