from abc import ABC, abstractmethod
import json

import httpx

from config import CHAT_MODEL_SPECS, OPEN_AI_KEY, OPEN_AI_ENTRYPOINT, OPEN_AI_PROVIDER


class ModelProvider(ABC):
    """
    Abstract base class for a model provider. This allows different
    backends (e.g., local, OpenAI API) to be used interchangeably.
    """

    def __init__(self, provider_name):
        self.provider_name = provider_name

    @abstractmethod
    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        """
        Generates a response from a model.
        :param model_id: The internal model ID to use.
        :param message: The user's message.
        :param chat_history: The current chat history.
        :param system_prompt: The system prompt to guide the model's behavior.
        :param temperature: The sampling temperature for the model.
        :return: A generator that yields the response.
        """
        pass
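

# A minimal sketch of a concrete provider, included for illustration only.
# "EchoProvider" is a hypothetical name, not part of this project; it shows the
# shape a new backend must take: get_response is a generator that appends to
# chat_history and yields it after each increment. model_id and temperature are
# accepted to satisfy the contract but unused here.
class EchoProvider(ModelProvider):
    """Toy provider that streams the user's message back one word at a time."""

    def __init__(self, provider_name="echo"):
        super().__init__(provider_name)

    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": ""})
        for word in message.split():
            chat_history[-1]["content"] += word + " "
            yield chat_history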


class OpenAICompatibleProvider(ModelProvider):
    """
    A model provider for any OpenAI-compatible API.
    """

    def __init__(self, provider_name="openai_compatible"):
        super().__init__(provider_name)
        self.api_key = OPEN_AI_KEY
        self.api_base = OPEN_AI_ENTRYPOINT
        if not self.api_key or not self.api_base:
            print("Warning: OPEN_AI_KEY or OPEN_AI_ENTRYPOINT not found in environment.")

    def get_response(self, model_id, message, chat_history, system_prompt, temperature=0.7):
        """
        Makes a real API call to an OpenAI-compatible API and streams the response.
        """
        print(f"DEBUG: Received system_prompt: {system_prompt}, temperature: {temperature}")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # Build the message history for the API call.
        messages_for_api = []
        if system_prompt:  # If a system prompt is set, it must come first.
            messages_for_api.append({"role": "system", "content": system_prompt})
        if chat_history:
            for item in chat_history:
                if isinstance(item, dict) and "role" in item and "content" in item:
                    messages_for_api.append(item)
        messages_for_api.append({"role": "user", "content": message})
        json_data = {
            "model": model_id,
            "messages": messages_for_api,
            "stream": True,
            "temperature": temperature,
        }
        # Append the user's message to chat_history for UI display, followed by
        # a placeholder that the assistant's streaming response will fill in.
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": ""})
        # Log the full request data (system prompt, history, user message, model ID).
        print("\n>>> DEBUG: get_response")
        print(">>> DEBUG: Sending request to OpenAI-compatible API")
        print(">>> : System prompt:", repr(system_prompt))
        print(">>> : Chat history:", repr(chat_history))
        print(">>> : User message:", repr(message))
        print(">>> : Model ID:", repr(model_id))
        print(">>> : Temperature:", repr(temperature))
        full_response = ""
        try:
            with httpx.stream(
                "POST",
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=json_data,
                timeout=120,
            ) as response:
                response.raise_for_status()
                # The response arrives as Server-Sent Events: each line is
                # "data: <json>", and the stream ends with "data: [DONE]".
                for chunk in response.iter_lines():
                    if chunk.startswith("data:"):
                        chunk = chunk[5:].strip()
                        if chunk == "[DONE]":
                            break
                        try:
                            data = json.loads(chunk)
                            if "choices" in data and data["choices"]:
                                delta = data["choices"][0].get("delta", {})
                                content_chunk = delta.get("content")
                                if content_chunk:
                                    full_response += content_chunk
                                    chat_history[-1]["content"] += content_chunk
                                    yield chat_history
                        except json.JSONDecodeError:
                            print(f"Error decoding JSON chunk: {chunk}")
            print(f"DEBUG: Full response: {full_response}")
        except Exception as e:
            print(f"XXX DEBUG: Error during API call: {e}")
            # Surface the error in the assistant's placeholder message.
            if chat_history and chat_history[-1]["role"] == "assistant":
                chat_history[-1]["content"] = f"An error occurred: {e}"
            else:
                chat_history.append({"role": "assistant", "content": f"An error occurred: {e}"})
            yield chat_history

    def get_code_response(self, model_id, system_prompt, user_prompt, temperature=0.7):
        """
        Makes a real API call for code generation and streams the response.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        messages_for_api = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        json_data = {
            "model": model_id,
            "messages": messages_for_api,
            "stream": True,
            "temperature": temperature,
        }
        print("\n>>> DEBUG: get_code_response")
        print(">>> DEBUG: Sending request to OpenAI-compatible API")
        print(">>> : System prompt:", repr(system_prompt))
        print(">>> : User message:", repr(user_prompt))
        print(">>> : Model ID:", repr(model_id))
        print(">>> : Temperature:", repr(temperature))
        full_response = ""
        try:
            with httpx.stream(
                "POST",
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=json_data,
                timeout=120,
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    if chunk.startswith("data:"):
                        chunk = chunk[5:].strip()
                        if chunk == "[DONE]":
                            break
                        try:
                            data = json.loads(chunk)
                            if "choices" in data and data["choices"]:
                                delta = data["choices"][0].get("delta", {})
                                content_chunk = delta.get("content")
                                if content_chunk:
                                    full_response += content_chunk
                                    yield content_chunk
                        except json.JSONDecodeError:
                            print(f"Error decoding JSON chunk: {chunk}")
            print(f"DEBUG: Full code response: {full_response}")
        except Exception as e:
            print(f"Error during API call: {e}")
            yield f"An error occurred: {e}"


class ModelHandler:
    """
    Manages different models and providers, acting as a facade for the UI.
    """

    def __init__(self):
        """
        Initializes the ModelHandler with the global CHAT_MODEL_SPECS.
        """
        self.config = CHAT_MODEL_SPECS
        self.providers = {
            "openai_compatible": OpenAICompatibleProvider(),
        }
        self.api_provider_brand = OPEN_AI_PROVIDER

    def get_response(self, model_constant, message, chat_history, system_prompt, temperature=0.7):
        """
        Gets a response from the appropriate model and provider.
        :param model_constant: The constant name of the model (e.g., LING_MODEL_A).
        :param message: The user's message.
        :param chat_history: The current chat history.
        :param system_prompt: The system prompt to guide the model's behavior.
        :param temperature: The sampling temperature for the model.
        :return: A generator that yields the response.
        """
        model_spec = self.config.get(model_constant, {})
        provider_name = model_spec.get("provider")
        model_id = model_spec.get("model_id")
        # Handle the case where chat_history might be None.
        if chat_history is None:
            chat_history = []
        if not provider_name or provider_name not in self.providers:
            full_response = f"Error: Model '{model_constant}' or its provider '{provider_name}' not configured."
            # Keep the history in the same role/content messages format used everywhere else.
            chat_history.append({"role": "user", "content": message})
            chat_history.append({"role": "assistant", "content": full_response})
            yield chat_history
            return
        provider = self.providers[provider_name]
        yield from provider.get_response(model_id, message, chat_history, system_prompt, temperature)

    def generate_code(self, system_prompt, user_prompt, model_choice):
        """
        Generates code using the specified model.
        """
        # model_choice is normally a display name; map it back to its constant.
        model_constant = next(
            (k for k, v in CHAT_MODEL_SPECS.items() if v["display_name"] == model_choice),
            None,
        )
        if not model_constant:
            # Fallback: model_choice may already be the constant itself.
            model_constant = model_choice if model_choice in CHAT_MODEL_SPECS else "LING_1T"
        model_spec = self.config.get(model_constant, {})
        provider_name = model_spec.get("provider")
        model_id = model_spec.get("model_id")
        if not provider_name or provider_name not in self.providers:
            yield f"Error: Model '{model_constant}' or its provider '{provider_name}' not configured."
            return
        provider = self.providers[provider_name]
        yield from provider.get_code_response(model_id, system_prompt, user_prompt)
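

# A minimal usage sketch, assuming OPEN_AI_KEY / OPEN_AI_ENTRYPOINT are set and
# that "LING_1T" (the fallback constant used in generate_code above) exists in
# CHAT_MODEL_SPECS; adjust the constant to match your config.
if __name__ == "__main__":
    handler = ModelHandler()
    history = []
    for updated in handler.get_response(
        "LING_1T", "Hello!", history, system_prompt="You are concise.", temperature=0.7
    ):
        pass  # each yield carries the full history, including the partial reply
    if history:
        print("Assistant:", history[-1]["content"])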