import os
from typing import Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer


class LLMFactory:
    """
    A factory class to create LLM wrapper instances based on the specified backend.
    Supported backends include:
    - http_api: HTTPClient
    - openai_api: OpenAIClient
    - ollama_api: OllamaClient
    - huggingface: HuggingFaceWrapper
    - sglang: SGLangWrapper
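
    Usage sketch (config keys depend on the chosen backend; the values
    shown here are illustrative, not project defaults):

        wrapper = LLMFactory.create_llm_wrapper(
            "openai_api",
            {"model": "gpt-4o-mini", "api_key": "sk-..."},
        )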
    """

    @staticmethod
    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
        # Inject a tokenizer into every wrapper's config; the model name
        # defaults to "cl100k_base" unless TOKENIZER_MODEL is set.
        tokenizer: Tokenizer = Tokenizer(
            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
        )
        config["tokenizer"] = tokenizer
        if backend == "http_api":
            from graphgen.models.llm.api.http_client import HTTPClient

            return HTTPClient(**config)
        if backend in ("openai_api", "azure_openai_api"):
            from graphgen.models.llm.api.openai_client import OpenAIClient
            # pass in concrete backend to the OpenAIClient so that internally we can distinguish
            # between OpenAI and Azure OpenAI
            return OpenAIClient(**config, backend=backend)
        if backend == "ollama_api":
            from graphgen.models.llm.api.ollama_client import OllamaClient

            return OllamaClient(**config)
        if backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            return HuggingFaceWrapper(**config)
        if backend == "sglang":
            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

            return SGLangWrapper(**config)

        # if backend == "vllm":
        #     from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
        #
        #     return VLLMWrapper(**config)

        raise NotImplementedError(f"Backend {backend} is not implemented yet.")


def _load_env_group(prefix: str) -> Dict[str, Any]:
    """
    Collect environment variables with the given prefix into a dictionary,
    stripping the prefix from the keys.
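
    Example (hypothetical variables): with SYNTHESIZER_BACKEND=openai_api and
    SYNTHESIZER_MODEL=gpt-4o-mini exported, _load_env_group("SYNTHESIZER_")
    returns {"backend": "openai_api", "model": "gpt-4o-mini"}.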
    """
    return {
        k[len(prefix) :].lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }


def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
    if model_type == "synthesizer":
        prefix = "SYNTHESIZER_"
    elif model_type == "trainee":
        prefix = "TRAINEE_"
    else:
        raise NotImplementedError(f"Model type {model_type} is not implemented yet.")
    config = _load_env_group(prefix)
    # No matching environment variables means this model is not configured.
    if not config:
        return None
    backend = config.pop("backend", None)
    if backend is None:
        raise ValueError(
            f"{prefix}BACKEND must be set to initialize the {model_type} model."
        )
    return LLMFactory.create_llm_wrapper(backend, config)
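

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes SYNTHESIZER_* variables are already
    # exported. The values below are hypothetical placeholders; the exact
    # config keys each client accepts are backend-specific.
    os.environ.setdefault("SYNTHESIZER_BACKEND", "http_api")
    os.environ.setdefault("SYNTHESIZER_BASE_URL", "http://localhost:8000/v1")
    os.environ.setdefault("SYNTHESIZER_MODEL", "my-model")

    synthesizer = init_llm("synthesizer")
    print("synthesizer:", type(synthesizer).__name__ if synthesizer else None)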