| """ This file contains the code for calling all LLM APIs. """ | |
| from pathlib import Path | |
| from .schema import TooLongPromptError, LLMError | |
| from functools import partial | |
| from transformers import AutoTokenizer | |
| import transformers | |
| import torch | |
| import os | |
| import time | |
| # try: | |
| # from huggingface_hub import login | |
| # login(os.environ["HF_TOKEN"]) | |
| # except Exception as e: | |
| # print(e) | |
| # print("Could not load hugging face token HF_TOKEN from environ") | |
try:
    import anthropic
    # set up the Anthropic API client
    anthropic_client = anthropic.Anthropic(api_key=os.environ['CLAUDE_API_KEY'])
except Exception as e:
    print(e)
    print("Could not load anthropic API key CLAUDE_API_KEY from environ")

try:
    import openai
    # the OpenAI client reads OPENAI_API_KEY from the environment
    openai_client = openai.OpenAI()
except Exception as e:
    print(e)
    print("Could not load OpenAI API key OPENAI_API_KEY from environ")


class LlamaAgent:
    """Thin wrapper around a local transformers text-generation pipeline for Llama chat models."""

    def __init__(
        self,
        model_name,
        temperature: float = 0.5,
        top_p: float = None,
        max_batch_size: int = 1,
        max_gen_len: int = 2000,
    ):
        from huggingface_hub import login
        # prompts for a Hugging Face token unless one is already cached
        login()
        model = f"meta-llama/{model_name}"
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        self.temperature = temperature
        self.top_p = top_p
        self.max_gen_len = max_gen_len

    def complete_text(
        self,
        prompts: list[str],
        max_gen_len=None,
        temperature=None,
        top_p=None,
        num_responses=1,
    ) -> list[str]:
        """Generate num_responses completions per prompt and return them as a flat list."""
        if max_gen_len is None:
            max_gen_len = self.max_gen_len
        if temperature is None:
            temperature = self.temperature
        if top_p is None:
            top_p = self.top_p
        results = []
        for prompt in prompts:
            seqs = self.pipeline(
                [{"role": "user", "content": prompt}],
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=num_responses,
                max_new_tokens=max_gen_len,
            )
            # each returned sequence is the full chat; the model's reply is the last message
            seqs = [s["generated_text"][-1]["content"] for s in seqs]
            results += seqs
        return results
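
# Illustrative usage of LlamaAgent (not executed at import time); the model name
# below is an assumption and requires access to the corresponding meta-llama checkpoint:
#   agent = LlamaAgent(model_name="Meta-Llama-3-8B-Instruct", temperature=0.7)
#   replies = agent.complete_text(["Say hello in one sentence."], max_gen_len=64)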

# cache of instantiated LlamaAgent objects, keyed by model name
agent_cache = {}


def complete_text_openai(prompt, stop_sequences=[], model="gpt-3.5-turbo", max_tokens_to_sample=2000, temperature=0.2):
    """ Call the OpenAI API to complete a prompt. """
    raw_request = {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens_to_sample,
        "stop": stop_sequences or None,  # API doesn't like empty list
    }
    messages = [{"role": "user", "content": prompt}]
    response = openai_client.chat.completions.create(messages=messages, **raw_request)
    completion = response.choices[0].message.content
    return completion
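
# Illustrative usage (assumes OPENAI_API_KEY is set; arguments shown are examples):
#   text = complete_text_openai("Summarize attention in one sentence.",
#                               stop_sequences=["Observation:"], temperature=0.0)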


def complete_text_claude(prompt, stop_sequences=[anthropic.HUMAN_PROMPT], model="claude-v1", max_tokens_to_sample=2000, temperature=0.5):
    """ Call the Claude API to complete a prompt, retrying on rate limits and transient server errors. """
    try:
        while True:
            try:
                message = anthropic_client.messages.create(
                    messages=[
                        {
                            "role": "user",
                            "content": prompt,
                        }
                    ],
                    model=model,
                    stop_sequences=stop_sequences,
                    temperature=temperature,
                    max_tokens=max_tokens_to_sample,
                )
            except anthropic.RateLimitError:
                time.sleep(0.1)
                continue
            except anthropic.InternalServerError:
                time.sleep(0.1)
                continue
            try:
                completion = message.content[0].text
                break
            except (IndexError, AttributeError):
                # the response carried no text block (e.g. an immediate end_turn); retry
                print("end_turn???")
    except anthropic.APIStatusError as e:
        print(e)
        raise TooLongPromptError()
    except Exception as e:
        raise LLMError(e)
    return completion
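
# Illustrative usage (assumes CLAUDE_API_KEY is set; the model id is an example,
# current ids follow the "claude-3-*-<date>" naming):
#   text = complete_text_claude("Name three sorting algorithms.",
#                               model="claude-3-haiku-20240307", temperature=0.0)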


def complete_multi_text(
    prompts: list[str], model: str,
    max_tokens_to_sample=None,
    temperature=0.5,
    top_p=None,
    responses_per_request=1,
) -> list[str]:
    """ Complete text using the specified model with the appropriate API. """
    if model.startswith("claude"):
        completions = []
        for prompt in prompts:
            for _ in range(responses_per_request):
                completion = complete_text_claude(
                    prompt,
                    stop_sequences=[anthropic.HUMAN_PROMPT, "Observation:"],
                    temperature=temperature,
                    model=model,
                )
                completions.append(completion)
        return completions
    elif model.startswith("gpt"):
        completions = []
        for prompt in prompts:
            for _ in range(responses_per_request):
                completion = complete_text_openai(
                    prompt,
                    stop_sequences=[anthropic.HUMAN_PROMPT, "Observation:"],
                    temperature=temperature,
                    model=model,
                )
                completions.append(completion)
        return completions
    else:  # llama
        if model not in agent_cache:
            agent_cache[model] = LlamaAgent(model_name=model)
        completions = []
        try:
            completions = agent_cache[model].complete_text(
                prompts=prompts,
                num_responses=responses_per_request,
                max_gen_len=max_tokens_to_sample,
                temperature=temperature,
                top_p=top_p,
            )
        except Exception as e:
            raise LLMError(e)
        return completions
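
# Illustrative usage (the model name is an example; the matching API key must be set):
#   outs = complete_multi_text(["What is 2+2?", "Name a prime."],
#                              model="gpt-3.5-turbo", responses_per_request=2)
#   # len(outs) == 4: responses_per_request completions per prompt, in prompt order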


def complete_text(
    prompt: str, model: str,
    max_tokens_to_sample=2000,
    temperature=0.5,
    top_p=None,
) -> str:
    """ Complete a single prompt and return one completion. """
    completion = complete_multi_text(
        prompts=[prompt],
        model=model,
        max_tokens_to_sample=max_tokens_to_sample,
        temperature=temperature,
        top_p=top_p,
    )[0]
    return completion


# specify fast models for summarization etc.
FAST_MODEL = "claude-3-haiku"


def complete_text_fast(prompt, *args, **kwargs):
    return complete_text(prompt, *args, model=FAST_MODEL, **kwargs)
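

if __name__ == "__main__":
    # Minimal smoke test (illustrative): run with `python -m <package>.<this_module>`
    # because of the relative import above; requires the matching API key in the
    # environment, and the model name here is only an example.
    print(complete_text("Reply with the single word: ok", model="gpt-3.5-turbo", temperature=0.0))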