# ling-series-spaces / model_handler.py
from abc import ABC, abstractmethod
import httpx
import json
from config import CHAT_MODEL_SPECS, OPEN_AI_KEY, OPEN_AI_ENTRYPOINT, OPEN_AI_PROVIDER
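
# NOTE (assumption): judging by the lookups in ModelHandler below, CHAT_MODEL_SPECS
# maps a model constant to a spec dict roughly like this (the values shown are
# illustrative, not confirmed by this file):
#   CHAT_MODEL_SPECS = {
#       "LING_1T": {
#           "display_name": "Ling 1T",        # hypothetical display name
#           "provider": "openai_compatible",
#           "model_id": "ling-1t",            # hypothetical upstream model ID
#       },
#   }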


class ModelProvider(ABC):
    """
    Abstract base class for a model provider. This allows different
    backends (e.g., local, OpenAI API) to be used interchangeably.
    """

    def __init__(self, provider_name):
        self.provider_name = provider_name

    @abstractmethod
    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        """
        Generates a response from a model.

        :param model_id: The internal model ID to use.
        :param message: The user's message.
        :param chat_history: The current chat history.
        :param system_prompt: The system prompt to guide the model's behavior.
        :param temperature: The sampling temperature for the model.
        :return: A generator that yields the response.
        """


class OpenAICompatibleProvider(ModelProvider):
    """
    A model provider for any OpenAI-compatible API.
    """

    def __init__(self, provider_name="openai_compatible"):
        super().__init__(provider_name)
        self.api_key = OPEN_AI_KEY
        self.api_base = OPEN_AI_ENTRYPOINT
        if not self.api_key or not self.api_base:
            print("Warning: OPEN_AI_KEY or OPEN_AI_ENTRYPOINT not found in environment.")

    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        """
        Makes a real API call to an OpenAI-compatible API and streams the response.
        """
        print(f"DEBUG: Received system_prompt: {system_prompt}, temperature: {temperature}")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # Build the message history for the API call.
        messages_for_api = []
        if system_prompt:  # If a system prompt is provided, it goes first.
            messages_for_api.append({"role": "system", "content": system_prompt})
        if chat_history:
            for item in chat_history:
                # Only forward well-formed messages-format entries.
                if isinstance(item, dict) and "role" in item and "content" in item:
                    messages_for_api.append(item)
        messages_for_api.append({"role": "user", "content": message})
        # Standard OpenAI chat-completions payload; "stream": True requests
        # server-sent events instead of a single JSON response body.
        json_data = {
            "model": model_id,
            "messages": messages_for_api,
            "stream": True,
            "temperature": temperature,
        }
        # Append the user's message to chat_history for UI display.
        chat_history.append({"role": "user", "content": message})
        # Placeholder for the assistant's streaming response.
        chat_history.append({"role": "assistant", "content": ""})
        # Log the full request payload here (system prompt, history, user message, model ID).
        print("\n>>> DEBUG: get_response")
        print(">>> DEBUG: Sending request to OpenAI-compatible API")
        print(">>> : System prompt:", repr(system_prompt))
        print(">>> : Chat history:", repr(chat_history))
        print(">>> : User message:", repr(message))
        print(">>> : Model ID:", repr(model_id))
        print(">>> : Temperature:", repr(temperature))
        full_response = ""
        try:
            with httpx.stream(
                "POST",
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=json_data,
                timeout=120,
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    # Server-sent events prefix each payload line with "data:".
                    if chunk.startswith("data:"):
                        chunk = chunk[5:].strip()
                        if chunk == "[DONE]":
                            break
                        try:
                            data = json.loads(chunk)
                            if "choices" in data and data["choices"]:
                                delta = data["choices"][0].get("delta", {})
                                content_chunk = delta.get("content")
                                if content_chunk:
                                    full_response += content_chunk
                                    chat_history[-1]["content"] += content_chunk
                                    yield chat_history
                        except json.JSONDecodeError:
                            print(f"Error decoding JSON chunk: {chunk}")
            print(f"DEBUG: Full response: {full_response}")
        except Exception as e:
            print(f"XXX DEBUG: Error during API call: {e}")
            # Ensure the assistant placeholder (last message) carries the error.
            if chat_history and chat_history[-1]["role"] == "assistant":
                chat_history[-1]["content"] = f"An error occurred: {e}"
            else:
                chat_history.append({"role": "assistant", "content": f"An error occurred: {e}"})
            yield chat_history
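
    # Usage sketch (assumption: the caller is a messages-format chat UI, e.g. a
    # Gradio Chatbot; the names below are illustrative, not part of this module):
    #
    #   provider = OpenAICompatibleProvider()
    #   history = []
    #   for updated_history in provider.get_response(
    #       "some-model-id", "Hello!", history, system_prompt="Be concise."
    #   ):
    #       render(updated_history)  # hypothetical per-chunk UI refresh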

    def get_code_response(self, model_id, system_prompt, user_prompt, temperature=0.7):
        """
        Makes a real API call for code generation and streams the response.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        messages_for_api = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        json_data = {
            "model": model_id,
            "messages": messages_for_api,
            "stream": True,
            "temperature": temperature,
        }
        print("\n>>> DEBUG: get_code_response")
        print(">>> DEBUG: Sending request to OpenAI-compatible API")
        print(">>> : System prompt:", repr(system_prompt))
        print(">>> : User message:", repr(user_prompt))
        print(">>> : Model ID:", repr(model_id))
        print(">>> : Temperature:", repr(temperature))
        full_response = ""
        try:
            with httpx.stream(
                "POST",
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=json_data,
                timeout=120,
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    if chunk.startswith("data:"):
                        chunk = chunk[5:].strip()
                        if chunk == "[DONE]":
                            break
                        try:
                            data = json.loads(chunk)
                            if "choices" in data and data["choices"]:
                                delta = data["choices"][0].get("delta", {})
                                content_chunk = delta.get("content")
                                if content_chunk:
                                    full_response += content_chunk
                                    yield content_chunk
                        except json.JSONDecodeError:
                            print(f"Error decoding JSON chunk: {chunk}")
            print(f"DEBUG: Full code response: {full_response}")
        except Exception as e:
            print(f"Error during API call: {e}")
            yield f"An error occurred: {e}"


class ModelHandler:
    """
    Manages different models and providers, acting as a facade for the UI.
    """

    def __init__(self):
        """
        Initializes the ModelHandler with the global CHAT_MODEL_SPECS.
        """
        self.config = CHAT_MODEL_SPECS
        self.providers = {
            "openai_compatible": OpenAICompatibleProvider(),
        }
        self.api_provider_brand = OPEN_AI_PROVIDER
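        # To plug in another backend, register it in self.providers above, along
        # these lines (hypothetical example, not an existing class):
        #   self.providers["my_backend"] = MyBackendProvider()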

    def get_response(self, model_constant, message, chat_history, system_prompt, temperature=0.7):
        """
        Gets a response from the appropriate model and provider.

        :param model_constant: The constant name of the model (e.g., LING_MODEL_A).
        :param message: The user's message.
        :param chat_history: The current chat history.
        :param system_prompt: The system prompt to guide the model's behavior.
        :param temperature: The temperature for the model.
        :return: A generator that yields the response.
        """
        model_spec = self.config.get(model_constant, {})
        provider_name = model_spec.get("provider")
        model_id = model_spec.get("model_id")
        # Handle the case where chat_history might be None.
        if chat_history is None:
            chat_history = []
        if not provider_name or provider_name not in self.providers:
            error_message = f"Error: Model '{model_constant}' or its provider '{provider_name}' is not configured."
            # Report the error in the same messages format used everywhere else.
            chat_history.append({"role": "user", "content": message})
            chat_history.append({"role": "assistant", "content": error_message})
            yield chat_history
            return
        provider = self.providers[provider_name]
        yield from provider.get_response(model_id, message, chat_history, system_prompt, temperature)

    def generate_code(self, system_prompt, user_prompt, model_choice):
        """
        Generates code using the specified model.
        """
        # model_choice is usually a display name; map it back to its constant.
        model_constant = next(
            (k for k, v in CHAT_MODEL_SPECS.items() if v.get("display_name") == model_choice),
            None,
        )
        if not model_constant:
            # Fall back: model_choice may already be the constant itself.
            model_constant = model_choice if model_choice in CHAT_MODEL_SPECS else "LING_1T"
        model_spec = self.config.get(model_constant, {})
        provider_name = model_spec.get("provider")
        model_id = model_spec.get("model_id")
        if not provider_name or provider_name not in self.providers:
            yield f"Error: Model '{model_constant}' or its provider '{provider_name}' is not configured."
            return
        provider = self.providers[provider_name]
        yield from provider.get_code_response(model_id, system_prompt, user_prompt)
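

# Minimal smoke-test sketch (assumption: the config module is importable and
# CHAT_MODEL_SPECS contains a valid entry; "LING_1T" is the fallback constant
# used above, but any configured key works here).
if __name__ == "__main__":
    handler = ModelHandler()
    history = []
    for updated in handler.get_response("LING_1T", "Say hello.", history, system_prompt=""):
        pass  # each yield is the full history, updated with streamed content
    if history:
        print("Assistant:", history[-1]["content"])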