import json
from abc import ABC, abstractmethod

import httpx

from config import CHAT_MODEL_SPECS, OPEN_AI_KEY, OPEN_AI_ENTRYPOINT, OPEN_AI_PROVIDER


class ModelProvider(ABC):
    """
    Abstract base class for a model provider. This allows for different
    backends (e.g., local, OpenAI API) to be used interchangeably.
    """

    def __init__(self, provider_name):
        self.provider_name = provider_name

    @abstractmethod
    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        """
        Generates a response from a model.

        :param model_id: The internal model ID to use.
        :param message: The user's message.
        :param chat_history: The current chat history.
        :param system_prompt: The system prompt to guide the model's behavior.
        :param temperature: The sampling temperature for the model.
        :return: A generator that yields the response.
        """


class OpenAICompatibleProvider(ModelProvider):
    """
    A model provider for any OpenAI-compatible API.
    """

    def __init__(self, provider_name="openai_compatible"):
        super().__init__(provider_name)
        self.api_key = OPEN_AI_KEY
        self.api_base = OPEN_AI_ENTRYPOINT
        if not self.api_key or not self.api_base:
            print("Warning: OPEN_AI_KEY or OPEN_AI_ENTRYPOINT not found in environment.")

    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        """
        Makes a real API call to an OpenAI-compatible API and streams the response.
        """
        print(f"DEBUG: Received system_prompt: {system_prompt}, temperature: {temperature}")  # Debug print
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # Build the message history for the API call
        messages_for_api = []
        if system_prompt:  # If a system prompt is provided, prepend it
            messages_for_api.append({"role": "system", "content": system_prompt})
        if chat_history:
            for item in chat_history:
                if isinstance(item, dict) and "role" in item and "content" in item:
                    messages_for_api.append(item)
        messages_for_api.append({"role": "user", "content": message})
        json_data = {
            "model": model_id,
            "messages": messages_for_api,
            "stream": True,
            "temperature": temperature,
        }
        # Append the user's message to chat_history for UI display
        chat_history.append({"role": "user", "content": message})
        # Initialize the assistant's response in chat_history
        chat_history.append({"role": "assistant", "content": ""})  # Placeholder for the assistant's streaming response
        # Log the full request data (system prompt, history, user message, model ID)
        print("\n>>> DEBUG: get_response")
        print(">>> DEBUG: Sending request to OpenAI-compatible API")
        print(">>> : System prompt:", repr(system_prompt))
        print(">>> : Chat history:", repr(chat_history))
        print(">>> : User message:", repr(message))
        print(">>> : Model ID:", repr(model_id))
        print(">>> : Temperature:", repr(temperature))
        full_response = ""
        try:
            with httpx.stream(
                "POST",
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=json_data,
                timeout=120,
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    if chunk.startswith("data:"):
                        chunk = chunk[5:].strip()
                        if chunk == "[DONE]":
                            break
                        try:
                            data = json.loads(chunk)
                            if "choices" in data and data["choices"]:
                                delta = data["choices"][0].get("delta", {})
                                content_chunk = delta.get("content")
                                if content_chunk:
                                    full_response += content_chunk
                                    chat_history[-1]["content"] += content_chunk
                                    yield chat_history
                        except json.JSONDecodeError:
                            print(f"Error decoding JSON chunk: {chunk}")
            print(f"DEBUG: Full response: {full_response}")
        except Exception as e:
            print(f"XXX DEBUG: Error during API call: {e}")
            # Ensure the last message (the assistant's placeholder) is updated with the error
            if chat_history and chat_history[-1]["role"] == "assistant":
                chat_history[-1]["content"] = f"An error occurred: {e}"
            else:
                chat_history.append({"role": "assistant", "content": f"An error occurred: {e}"})
            yield chat_history
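
    # The streaming loops in get_response above and get_code_response below parse
    # Server-Sent Events from an OpenAI-compatible /chat/completions endpoint. Each
    # event line carries a JSON payload after the "data:" prefix, for example
    # (illustrative payload, trimmed to the fields this code actually reads):
    #
    #     data: {"choices": [{"delta": {"content": "Hel"}}]}
    #     data: {"choices": [{"delta": {"content": "lo"}}]}
    #     data: [DONE]
    #
    # The "delta.content" fragments are concatenated to rebuild the full reply.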

    def get_code_response(self, model_id, system_prompt, user_prompt, temperature=0.7):
        """
        Makes a real API call for code generation and streams the response.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        messages_for_api = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        json_data = {
            "model": model_id,
            "messages": messages_for_api,
            "stream": True,
            "temperature": temperature,
        }
        print("\n>>> DEBUG: get_code_response")
        print(">>> DEBUG: Sending request to OpenAI-compatible API")
        print(">>> : System prompt:", repr(system_prompt))
        print(">>> : User message:", repr(user_prompt))
        print(">>> : Model ID:", repr(model_id))
        print(">>> : Temperature:", repr(temperature))
        full_response = ""
        try:
            with httpx.stream("POST", f"{self.api_base}/chat/completions", headers=headers, json=json_data, timeout=120) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    if chunk.startswith("data:"):
                        chunk = chunk[5:].strip()
                        if chunk == "[DONE]":
                            break
                        try:
                            data = json.loads(chunk)
                            if "choices" in data and data["choices"]:
                                delta = data["choices"][0].get("delta", {})
                                content_chunk = delta.get("content")
                                if content_chunk:
                                    full_response += content_chunk
                                    yield content_chunk
                        except json.JSONDecodeError:
                            print(f"Error decoding JSON chunk: {chunk}")
            print(f"DEBUG: Full code response: {full_response}")
        except Exception as e:
            print(f"Error during API call: {e}")
            yield f"An error occurred: {e}"


class ModelHandler:
    """
    Manages different models and providers, acting as a facade for the UI.
    """

    def __init__(self):
        """
        Initializes the ModelHandler with the global CHAT_MODEL_SPECS.
        """
        self.config = CHAT_MODEL_SPECS
        self.providers = {
            "openai_compatible": OpenAICompatibleProvider(),
        }
        self.api_provider_brand = OPEN_AI_PROVIDER
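
    # Each entry in CHAT_MODEL_SPECS is expected to be a dict describing one model.
    # The keys read by this class are "provider", "model_id", and "display_name",
    # e.g. (illustrative values only, not taken from the real config):
    #
    #     "SOME_MODEL": {
    #         "provider": "openai_compatible",
    #         "model_id": "some-model-id",
    #         "display_name": "Some Model",
    #     }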

    def get_response(self, model_constant, message, chat_history, system_prompt, temperature=0.7):
        """
        Gets a response from the appropriate model and provider.

        :param model_constant: The constant name of the model (e.g., LING_MODEL_A).
        :param message: The user's message.
        :param chat_history: The current chat history.
        :param system_prompt: The system prompt to guide the model's behavior.
        :param temperature: The temperature for the model.
        :return: A generator that yields the response.
        """
        model_spec = self.config.get(model_constant, {})
        provider_name = model_spec.get("provider")
        model_id = model_spec.get("model_id")
        # Handle the case where chat_history might be None
        if chat_history is None:
            chat_history = []
        if not provider_name or provider_name not in self.providers:
            full_response = f"Error: Model '{model_constant}' or its provider '{provider_name}' not configured."
            # Keep the history in the same role/content message format used elsewhere
            chat_history.append({"role": "user", "content": message})
            chat_history.append({"role": "assistant", "content": full_response})
            yield chat_history
            return
        provider = self.providers[provider_name]
        yield from provider.get_response(model_id, message, chat_history, system_prompt, temperature)

    def generate_code(self, system_prompt, user_prompt, model_choice):
        """
        Generates code using the specified model.
        """
        model_constant = next(
            (k for k, v in CHAT_MODEL_SPECS.items() if v["display_name"] == model_choice),
            None,
        )
        if not model_constant:
            # Fallback if the display name is not found; model_choice may be the constant itself
            model_constant = model_choice if model_choice in CHAT_MODEL_SPECS else "LING_1T"
        model_spec = self.config.get(model_constant, {})
        provider_name = model_spec.get("provider")
        model_id = model_spec.get("model_id")
        if not provider_name or provider_name not in self.providers:
            yield f"Error: Model '{model_constant}' or its provider '{provider_name}' not configured."
            return
        provider = self.providers[provider_name]
        yield from provider.get_code_response(model_id, system_prompt, user_prompt)
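

# A minimal manual smoke test (a sketch, not part of the app itself): it assumes
# CHAT_MODEL_SPECS contains at least one entry and that the configured endpoint and
# key are valid. It picks the first configured model, runs one chat turn, and prints
# the final assistant reply.
if __name__ == "__main__":
    handler = ModelHandler()
    first_model_constant = next(iter(CHAT_MODEL_SPECS))  # any configured model constant
    history = []
    for updated_history in handler.get_response(
        first_model_constant,
        "Hello! Please introduce yourself in one sentence.",
        history,
        system_prompt="You are a helpful assistant.",
        temperature=0.7,
    ):
        # Each yield is the full chat history, with the assistant's reply growing over time.
        pass
    if history and history[-1]["role"] == "assistant":
        print(history[-1]["content"])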