File size: 1,714 Bytes
799ac7c
 
 
1189434
 
799ac7c
 
 
 
 
 
 
 
 
1189434
 
799ac7c
 
1189434
799ac7c
 
 
 
 
 
 
 
 
 
 
 
9e67c3b
799ac7c
9e67c3b
799ac7c
 
 
 
 
 
 
1189434
 
799ac7c
 
 
 
 
1189434
 
799ac7c
 
 
9e67c3b
799ac7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import asyncio
from typing import List, Tuple

import gradio as gr

from graphgen.bases import BaseGraphStorage, BaseTokenizer
from graphgen.utils import run_concurrent


async def pre_tokenize(
    graph_storage: BaseGraphStorage,
    tokenizer: BaseTokenizer,
    edges: List[Tuple],
    nodes: List[Tuple],
    progress_bar: gr.Progress = None,
    max_concurrent: int = 1000,
) -> Tuple[List, List]:
    """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。"""
    sem = asyncio.Semaphore(max_concurrent)

    async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
        async with sem:
            data = obj[1] if is_node else obj[2]
            if "length" not in data:
                loop = asyncio.get_event_loop()
                data["length"] = len(
                    await loop.run_in_executor(
                        None, tokenizer.encode, data["description"]
                    )
                )
            if is_node:
                graph_storage.update_node(obj[0], obj[1])
            else:
                graph_storage.update_edge(obj[0], obj[1], obj[2])
            return obj

    new_edges, new_nodes = await asyncio.gather(
        run_concurrent(
            lambda e: _patch_and_write(e, is_node=False),
            edges,
            desc="Pre-tokenizing edges",
            unit="edge",
            progress_bar=progress_bar,
        ),
        run_concurrent(
            lambda n: _patch_and_write(n, is_node=True),
            nodes,
            desc="Pre-tokenizing nodes",
            unit="node",
            progress_bar=progress_bar,
        ),
    )

    graph_storage.index_done_callback()
    return new_edges, new_nodes