import os
from typing import Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer


class LLMFactory:
    """
    A factory class to create LLM wrapper instances based on the specified backend.
    Supported backends include:
    - http_api: HTTPClient
    - openai_api: OpenAIClient
    - ollama_api: OllamaClient
    - ollama: OllamaWrapper
    - deepspeed: DeepSpeedWrapper
    - huggingface: HuggingFaceWrapper
    - tgi: TGIWrapper
    - sglang: SGLangWrapper
    - tensorrt: TensorRTWrapper
    """

    @staticmethod
    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
        # Attach a shared tokenizer to the config; the encoding name comes
        # from the TOKENIZER_MODEL env var (defaults to cl100k_base).
        tokenizer: Tokenizer = Tokenizer(
            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
        )
        config["tokenizer"] = tokenizer
        if backend == "http_api":
            from graphgen.models.llm.api.http_client import HTTPClient

            return HTTPClient(**config)
        if backend == "openai_api":
            from graphgen.models.llm.api.openai_client import OpenAIClient

            return OpenAIClient(**config)
        if backend == "ollama_api":
            from graphgen.models.llm.api.ollama_client import OllamaClient

            return OllamaClient(**config)
        if backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            return HuggingFaceWrapper(**config)
        # if backend == "sglang":
        #     from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
        #
        #     return SGLangWrapper(**config)

        if backend == "vllm":
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            return VLLMWrapper(**config)

        raise NotImplementedError(f"Backend {backend} is not implemented yet.")


def _load_env_group(prefix: str) -> Dict[str, Any]:
    """
    Collect environment variables with the given prefix into a dictionary,
    stripping the prefix from the keys.
    """
    return {
        k[len(prefix) :].lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }
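
# For illustration only: with hypothetical environment variables
#   SYNTHESIZER_BACKEND=openai_api
#   SYNTHESIZER_BASE_URL=http://localhost:8000
# _load_env_group("SYNTHESIZER_") returns
#   {"backend": "openai_api", "base_url": "http://localhost:8000"}.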


def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
    if model_type == "synthesizer":
        prefix = "SYNTHESIZER_"
    elif model_type == "trainee":
        prefix = "TRAINEE_"
    else:
        raise NotImplementedError(f"Model type {model_type} is not implemented yet.")
    config = _load_env_group(prefix)
    # No matching environment variables were found: this model is not configured.
    if not config:
        return None
    backend = config.pop("backend", None)
    if backend is None:
        raise ValueError(f"{prefix}BACKEND must be set to select an LLM backend.")
    return LLMFactory.create_llm_wrapper(backend, config)
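

# Minimal usage sketch (illustrative only; the environment values below are
# hypothetical and not defined by this module):
#
#   os.environ["SYNTHESIZER_BACKEND"] = "openai_api"
#   os.environ["SYNTHESIZER_API_KEY"] = "sk-..."          # placeholder
#   os.environ["SYNTHESIZER_MODEL_NAME"] = "gpt-4o-mini"  # placeholder
#   synthesizer = init_llm("synthesizer")  # -> OpenAIClient instance
#
# Each SYNTHESIZER_* variable (prefix stripped, key lowercased) is forwarded
# as a keyword argument to the selected wrapper's constructor.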