import requests
import time
import os
import re
import json

from huggingface_hub import InferenceClient
# Helper function to parse the response
def parse_thinking_response(response_text):
    """
    Parses a model's response to separate the thinking process
    from the final answer.
    """
    match = re.search(r"<think>(.*?)</think>(.*)", response_text, re.DOTALL)
    if match:
        thinking = match.group(1).strip()
        final_answer = match.group(2).strip()
        return thinking, final_answer
    else:
        return None, response_text.strip()
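
# A minimal usage sketch for the parser above; the sample string is
# illustrative only, not taken from a real model response:
#
#   thinking, answer = parse_thinking_response(
#       "<think>2 + 2 equals 4.</think>The answer is 4."
#   )
#   # thinking -> "2 + 2 equals 4."
#   # answer   -> "The answer is 4."
#
# If no <think> block is present, thinking is None and the full text is
# returned as the answer.
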
def get_inference_endpoint_response(
    model,
    messages,
    temperature,
    top_p,
    max_tokens
):
    """
    Queries a model via the Hugging Face serverless Inference API
    (pay-as-you-go), letting the client auto-select a provider.
    """
    client = InferenceClient(
        provider="auto",
        api_key=os.getenv("HF_API_KEY")
    )
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens
    )
    # Get the raw response content from the first choice
    raw_response = completion.choices[0].message.content
    return raw_response
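
# A usage sketch for the serverless helper; the model id and prompt below
# are placeholders (assumptions, not part of the original file), and
# HF_API_KEY must be set in the environment:
#
#   raw = get_inference_endpoint_response(
#       model="deepseek-ai/DeepSeek-R1",  # hypothetical example model
#       messages=[{"role": "user", "content": "Say hello."}],
#       temperature=0.6,
#       top_p=0.95,
#       max_tokens=256,
#   )
#   thinking, answer = parse_thinking_response(raw)
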
def get_custom_inference_endpoint_response(
    messages: list,
    use_expert: bool = True,
    tokenizer_max_length: int = 512,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_k: int = 50,
    top_p: float = 0.95,
    num_beams: int = 1,
    max_new_tokens: int = 1024,
    **kwargs  # To catch any other unused arguments
):
    """
    Contacts a custom Hugging Face inference endpoint with retry logic.

    This function is tailored to a custom EndpointHandler that expects a
    specific payload structure:
    {"inputs": {"messages": [...], "settings": {...}}}.
    """
    endpoint_url = os.getenv("HF_ENDPOINT_URL")
    hf_endpoint_token = os.getenv("HF_ENDPOINT_TOKEN")
    if not endpoint_url or not hf_endpoint_token:
        return "Error: HF_ENDPOINT_URL and HF_ENDPOINT_TOKEN environment variables must be set."
    headers = {
        "Authorization": f"Bearer {hf_endpoint_token}",
        "Content-Type": "application/json"
    }

    # --- PAYLOAD STRUCTURE FOR THE CUSTOM ENDPOINT HANDLER ---
    # This handler expects a 'settings' dictionary nested inside 'inputs'.
    settings = {
        "use_expert": use_expert,
        "tokenizer_max_length": tokenizer_max_length,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "num_beams": num_beams,
        "max_new_tokens": max_new_tokens,
    }

    # The server-side EndpointHandler is designed to handle parameter logic,
    # so all parameters are sent from the client. The final payload must
    # match the nested structure the custom handler expects.
    payload = {
        "inputs": {
            "messages": messages,
            "settings": settings
        }
    }
    # --- Retry Logic ---
    max_retries = 5
    wait_time = 30  # seconds to wait between retries

    for attempt in range(max_retries):
        print(f"Attempting to contact endpoint, attempt {attempt + 1}/{max_retries}...")
        # Log the exact payload being sent for easier debugging
        print(f"Payload: {json.dumps(payload, indent=2)}")
        try:
            # A timeout keeps the client from hanging indefinitely on a
            # cold or unresponsive endpoint (300 s is an assumed value).
            response = requests.post(endpoint_url, headers=headers, json=payload, timeout=300)
            # Raise an exception for bad status codes (4xx or 5xx)
            response.raise_for_status()
            result = response.json()
            print(f"Success! Response: {result}")
            # The custom handler returns a dictionary with a 'response' key;
            # extract it, falling back to an error message if it is missing.
            return result.get('response', 'Error: "response" key not found in the result.')
        except requests.exceptions.HTTPError as errh:
            # A 503 usually means the endpoint is still scaling up from
            # zero, so wait and retry instead of failing immediately.
            if errh.response.status_code == 503 and attempt < max_retries - 1:
                print(f"Service Unavailable (503). Endpoint may be starting up. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                error_message = f"HTTP Error: {errh}\nResponse: {errh.response.text}"
                print(error_message)
                return error_message
        except json.JSONDecodeError:
            # Caught before RequestException: in recent versions of requests,
            # the JSONDecodeError raised by response.json() subclasses both,
            # so this branch would otherwise be unreachable.
            error_message = f"JSON Decode Error: Failed to parse response from server.\nResponse Text: {response.text}"
            print(error_message)
            return error_message
        except requests.exceptions.RequestException as err:
            error_message = f"Request Error: {err}"
            print(error_message)
            return error_message

    return "Error: Failed to get a response after multiple retries."