import json
from typing import Dict, Any
from llama_cpp import Llama
import gemma_tools as gem

# Context window size for the model.
MAX_TOKENS = 8192
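
# EndpointHandler follows the Hugging Face Inference Endpoints custom-handler
# convention: it is constructed once at startup, then invoked with each
# request payload via __call__.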
class EndpointHandler:
    def __init__(self, data):
        # Pull the quantized GGUF weights straight from the Hugging Face Hub;
        # the constructor argument is unused for that reason.
        self.model = Llama.from_pretrained(
            "lmstudio-ai/gemma-2b-it-GGUF",
            filename="gemma-2b-it-q4_k_m.gguf",
            n_ctx=MAX_TOKENS,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Validate and extract the request arguments via the helper module.
        args = gem.get_args_or_none(data)
        # Gemma's chat template delimits turns with <start_of_turn>/<end_of_turn>.
        # (Gemma officially defines only "user" and "model" roles; the system
        # turn here reuses the same delimiters.)
        fmat = "<start_of_turn>system\n{system_prompt} <end_of_turn>\n<start_of_turn>user\n{prompt} <end_of_turn>\n<start_of_turn>model"
        # get_args_or_none is assumed to return a dict carrying "status" and
        # "description" keys when validation fails.
        if args is None or args.get("status") == "error":
            return {
                "status": args["status"] if args else "error",
                "message": args["description"] if args else "missing arguments",
            }
        try:
            fmat = fmat.format(system_prompt=args["system_prompt"], prompt=args["inputs"])
        except (KeyError, IndexError):
            return {
                "status": "error",
                "reason": "missing system_prompt or inputs in request",
            }
        max_length = data.pop("max_length", 512)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            return {
                "status": "error",
                "reason": "max_length must be an integer",
            }
        # Generate with the sampling parameters extracted from the request.
        res = self.model(
            fmat,
            temperature=args["temperature"],
            top_p=args["top_p"],
            top_k=args["top_k"],
            max_tokens=max_length,
        )
        return res
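

# A minimal local smoke test, assuming the handler is deployed as a Hugging
# Face Inference Endpoints custom handler. The payload keys below (inputs,
# system_prompt, temperature, top_p, top_k, max_length) mirror what the code
# above reads; the exact schema accepted by gemma_tools.get_args_or_none is
# an assumption.
if __name__ == "__main__":
    handler = EndpointHandler(None)
    payload = {
        "inputs": "Explain what a context window is in one sentence.",
        "system_prompt": "You are a concise assistant.",
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "max_length": 128,
    }
    # llama-cpp-python returns a completion dict; the text, if generation
    # succeeded, lives under choices[0]["text"].
    print(handler(payload))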