import os
from typing import Any, Dict, Optional

from graphgen.bases import BaseLLMWrapper
from graphgen.models import Tokenizer


class LLMFactory:
    """
    A factory class to create LLM wrapper instances based on the specified backend.

    Supported backends include:
    - http_api: HTTPClient
    - openai_api / azure_openai_api: OpenAIClient
    - ollama_api: OllamaClient
    - huggingface: HuggingFaceWrapper
    - sglang: SGLangWrapper
    """

    @staticmethod
    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
        # Attach a shared tokenizer so every wrapper counts tokens the same way;
        # the encoding defaults to "cl100k_base" unless TOKENIZER_MODEL is set.
        tokenizer: Tokenizer = Tokenizer(
            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
        )
        config["tokenizer"] = tokenizer

        if backend == "http_api":
            from graphgen.models.llm.api.http_client import HTTPClient

            return HTTPClient(**config)
        if backend in ("openai_api", "azure_openai_api"):
            from graphgen.models.llm.api.openai_client import OpenAIClient

            # Pass the concrete backend to OpenAIClient so it can distinguish
            # between OpenAI and Azure OpenAI internally.
            return OpenAIClient(**config, backend=backend)
        if backend == "ollama_api":
            from graphgen.models.llm.api.ollama_client import OllamaClient

            return OllamaClient(**config)
        if backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            return HuggingFaceWrapper(**config)
        if backend == "sglang":
            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

            return SGLangWrapper(**config)
        # if backend == "vllm":
        #     from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
        #
        #     return VLLMWrapper(**config)
        raise NotImplementedError(f"Backend {backend} is not implemented yet.")
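
# Illustrative only, not part of the factory: a minimal sketch of calling the
# factory directly. The config keys ("api_key", "model") and their values are
# assumptions here and must match whatever the chosen wrapper's __init__
# actually accepts.
#
#     llm = LLMFactory.create_llm_wrapper(
#         backend="openai_api",
#         config={"api_key": "sk-...", "model": "my-model"},  # hypothetical values
#     )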


def _load_env_group(prefix: str) -> Dict[str, Any]:
    """
    Collect environment variables with the given prefix into a dictionary,
    stripping the prefix from the keys.
    """
    return {
        k[len(prefix) :].lower(): v
        for k, v in os.environ.items()
        if k.startswith(prefix)
    }
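
# For example, a sketch assuming these variables are set in the environment
# (the MODEL value is hypothetical):
#
#     SYNTHESIZER_BACKEND=openai_api
#     SYNTHESIZER_MODEL=my-model
#
#     _load_env_group("SYNTHESIZER_")
#     # -> {"backend": "openai_api", "model": "my-model"}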


def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
    if model_type == "synthesizer":
        prefix = "SYNTHESIZER_"
    elif model_type == "trainee":
        prefix = "TRAINEE_"
    else:
        raise NotImplementedError(f"Model type {model_type} is not implemented yet.")

    config = _load_env_group(prefix)
    # No variables with this prefix means the model is not configured;
    # return None rather than raising.
    if not config:
        return None

    # "backend" is required; it is consumed here so the remaining keys can be
    # forwarded directly to the wrapper's constructor.
    backend = config.pop("backend")
    llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
    return llm_wrapper
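
# A minimal usage sketch, assuming the SYNTHESIZER_* variables are exported
# before this is called; the exact keys OpenAIClient accepts are an assumption:
#
#     os.environ["SYNTHESIZER_BACKEND"] = "openai_api"
#     os.environ["SYNTHESIZER_API_KEY"] = "sk-..."   # placeholder
#     os.environ["SYNTHESIZER_MODEL"] = "my-model"   # hypothetical name
#
#     synthesizer = init_llm("synthesizer")  # OpenAIClient, or None if unconfigured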