# MedAI_VLM/src/server.py
"""VLM client with file-based session history - supports both local GGUF and HF Spaces"""
import aiohttp
import asyncio
import json
import base64
import os
from typing import Optional, Dict, Any, AsyncGenerator

from .logger import setup_logger
from .config import IS_HF_SPACE, MAX_CONCURRENT_USERS
from .session_manager import session_manager

# Conditional imports based on environment
if not IS_HF_SPACE:
    from .config import SERVER_URL, REQUEST_TIMEOUT, CONNECT_TIMEOUT

SEMAPHORE = asyncio.Semaphore(MAX_CONCURRENT_USERS)
logger = setup_logger(__name__)

if IS_HF_SPACE:
    logger.info("🌐 Running in HF Spaces mode (transformers)")
else:
    logger.info("💻 Running in local GGUF mode (llama-cpp)")

def get_active_sessions() -> list:
    """Get the list of currently active sessions"""
    return session_manager.get_active_sessions()


def cleanup_expired_sessions() -> int:
    """Clean up sessions that have exceeded their TTL"""
    return session_manager.cleanup_expired_sessions()
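# The two wrappers above expose session housekeeping without importing
# session_manager everywhere. A background janitor could reuse them like this
# (hypothetical wiring; the 60 s interval is made up, and create_task must be
# called from a running event loop):
#     async def _janitor():
#         while True:
#             cleanup_expired_sessions()
#             await asyncio.sleep(60)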

async def _encode_image(path: str) -> tuple[Optional[str], Optional[str]]:
    """
    Encode an image to base64 and detect its format.

    Returns: (base64_string, mime_type), or (None, None) if the file is missing
    """
    def _encode():
        if not os.path.exists(path):
            logger.warning(f"Image not found: {path}")
            return None, None
        with open(path, "rb") as f:
            img_bytes = f.read()
        # Detect the image format from magic bytes
        if img_bytes.startswith(b'\x89PNG'):
            mime_type = "image/png"
        elif img_bytes.startswith(b'\xFF\xD8\xFF'):
            mime_type = "image/jpeg"
        elif img_bytes.startswith(b'GIF'):
            mime_type = "image/gif"
        elif img_bytes.startswith(b'RIFF') and b'WEBP' in img_bytes[:12]:
            mime_type = "image/webp"
        else:
            mime_type = "image/jpeg"  # Default fallback
        return base64.b64encode(img_bytes).decode(), mime_type

    # Run the blocking file read off the event loop
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _encode)
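# Usage sketch (illustrative; the file name is hypothetical):
#     b64, mime = await _encode_image("uploads/scan_001.png")
#     if b64:
#         url = f"data:{mime};base64,{b64}"   # e.g. "data:image/png;base64,iVBOR..."
# Missing files come back as (None, None), so callers must check before embedding.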

def _extract_text_from_content(content: Any) -> str:
    """Extract text from message content (handles string/dict/list formats)"""
    if isinstance(content, str):
        return content
    elif isinstance(content, dict):
        return content.get("text", "")
    elif isinstance(content, list):
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                return item.get("text", "")
    return ""

def _find_image_in_context_window(all_messages: list, max_messages: int, session_id: str) -> Optional[str]:
    """
    Find the most recent image that will remain in the context window after
    the new user message and assistant response are added.

    Args:
        all_messages: Current message history
        max_messages: Maximum messages to keep in the window
        session_id: Session identifier for logging

    Returns:
        Path to the image file, or None
    """
    # Work out which messages survive once the current user message and the
    # assistant response are appended: the oldest exchange (2 messages) is dropped
    if len(all_messages) >= max_messages - 1:
        messages_that_will_remain = all_messages[2:] if len(all_messages) >= 2 else all_messages
    else:
        messages_that_will_remain = all_messages

    # Find the most recent image among the surviving messages
    for msg in reversed(messages_that_will_remain):
        img_path = msg.get("image_path")
        if img_path and os.path.exists(img_path):
            logger.info(f"🔍 Session {session_id[:8]} | Found image in context window")
            return img_path

    # Check whether an image is about to be dropped from the window
    if len(all_messages) >= max_messages - 1:
        for msg in all_messages[:2]:  # Check the dropped messages
            if msg.get("image_path"):
                logger.info(f"🗑️ Session {session_id[:8]} | Image outside context window (will be deleted)")
    return None
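# Worked example (numbers hypothetical): with HISTORY_EXCHANGES = 5 the caller
# passes max_messages = 10. Once the stored history reaches 9 messages
# (>= max_messages - 1), appending the new user turn plus the assistant reply
# would overflow the window, so the oldest exchange (all_messages[:2]) is
# treated as already dropped and any image attached to it is ignored.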

def _build_history_messages(session_id: str) -> list[Dict[str, Any]]:
    """Build the API message history from session files"""
    messages = session_manager.get_messages_for_context(session_id)
    return [{"role": msg["role"], "content": msg["content"]} for msg in messages]


def _find_most_recent_image(session_id: str) -> Optional[str]:
    """Find the most recent image in the session history"""
    messages = session_manager.get_messages_for_context(session_id)
    for msg in reversed(messages):
        if msg.get("image_path") and os.path.exists(msg["image_path"]):
            return msg["image_path"]
    return None

async def respond_stream_hf(
    message,
    history,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    session_id: Optional[str] = None,
    model_choice: str = "Base (General)",
) -> AsyncGenerator[str, None]:
    """
    HF Spaces streaming response using transformers.

    Args:
        message: Current user message (text or dict with text/files)
        history: Gradio history
        system_message: System prompt
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        session_id: Optional session identifier
        model_choice: Model selection string
    """
    from .config import HF_BASE_MODEL, HF_FT_MODEL
    from .server_hf import run_hf_inference

    # Map the model choice to a model ID
    model_id = HF_FT_MODEL if model_choice == "Fine-Tuned (BraTS)" else HF_BASE_MODEL

    # Generate a session ID if not provided
    if session_id is None:
        session_id = str(id(history)) if history is not None else "default"

    # Log the request
    model_type = "FT" if model_choice == "Fine-Tuned (BraTS)" else "Base"
    has_image = isinstance(message, dict) and message.get("files")
    img_status = "with image" if has_image else "text only"
    logger.info(f"🤖 HF Spaces | Session: {session_id[:8]} | Model: {model_type} | {img_status}")
# Convert Gradio history to HF format
hf_history = []
for item in history:
role = item.get("role", "user")
content = item.get("content", "")
# Handle multimodal content
if isinstance(content, list):
# Already in multimodal format
hf_history.append({"role": role, "content": content})
elif isinstance(content, str):
hf_history.append({"role": role, "content": content})
else:
# Image tuple format from Gradio
hf_history.append({"role": role, "content": content})
    # Stream the response from HF inference (a sync generator driven from async code)
    try:
        loop = asyncio.get_running_loop()
        # Note: temperature/top_p are not forwarded to run_hf_inference
        gen = run_hf_inference(message, hf_history, system_message, max_tokens, model_id)

        # Advance the sync generator on a worker thread so the event loop is not blocked
        def get_next(g):
            try:
                return next(g), False
            except StopIteration:
                return None, True

        while True:
            chunk, done = await loop.run_in_executor(None, get_next, gen)
            if done:
                break
            if chunk is not None:
                yield chunk
            # Small delay to let the UI update
            await asyncio.sleep(0.01)
    except Exception as e:
        logger.error(f"❌ HF inference error: {e}")
        yield f"Error: {str(e)}"

async def respond_stream(
    message,
    history,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    session_id: Optional[str] = None,
    server_url: Optional[str] = None,
    model_choice: str = "Base (General)",
) -> AsyncGenerator[str, None]:
    """
    Stream a VLM response - routes to HF Spaces or local GGUF based on environment.

    Args:
        message: Current user message (text or dict with text/files)
        history: Gradio history (used to extract session info)
        system_message: System prompt
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        session_id: Optional session identifier
        server_url: Optional custom server URL (local only)
        model_choice: Model selection string
    """
    # Route to the appropriate inference backend
    if IS_HF_SPACE:
        async for chunk in respond_stream_hf(message, history, system_message, max_tokens, temperature, top_p, session_id, model_choice):
            yield chunk
        return

    # Local GGUF inference below

    # Clean up expired sessions (runs opportunistically on every request)
    cleanup_expired_sessions()

    # Generate a session ID from the history object if not provided
    if session_id is None:
        session_id = str(id(history)) if history is not None else "default"

    # Ensure the session exists
    if not session_manager.session_exists(session_id):
        session_manager.create_session(session_id)
    session_manager.update_activity(session_id)

    active_count = len(get_active_sessions())
    history_msgs = session_manager.get_session_history(session_id)
    # Extract the text of the current message
    prompt = _extract_text_from_content(message.get("text") if isinstance(message, dict) else message)

    # Handle a new image upload
    current_image_path = None
    has_image = isinstance(message, dict) and message.get("files") and len(message["files"]) > 0
    if has_image:
        current_image_path = message["files"][0]

    # Build the conversation history from session files
    all_messages = session_manager.get_messages_for_context(session_id)

    # Log a request summary
    img_status = "with image" if has_image else "text only"
    logger.info(f"Request | Session: {session_id[:8]} | {img_status} | History: {len(history_msgs)} msgs")

    # Determine which image to use (new upload, or one cached in history)
    from .config import HISTORY_EXCHANGES
    max_messages = HISTORY_EXCHANGES * 2
    if current_image_path:
        image_to_use = current_image_path
    else:
        image_to_use = _find_image_in_context_window(all_messages, max_messages, session_id)

    # Build API messages from history
    messages = [{"role": msg["role"], "content": msg["content"]} for msg in all_messages]

    # Build the current message, with or without an image
    if image_to_use:
        image_b64, mime_type = await _encode_image(image_to_use)
        if image_b64:
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_b64}"}},
                    {"type": "text", "text": prompt}
                ]
            })
            if current_image_path:
                logger.info(f"📤 Sending WITH NEW IMAGE ({mime_type})")
            else:
                logger.info(f"📤 Sending WITH IMAGE from history ({mime_type})")
        else:
            messages.append({"role": "user", "content": prompt})
            logger.info("📤 TEXT ONLY (image load failed)")
    else:
        messages.append({"role": "user", "content": prompt})
        logger.info("📤 TEXT ONLY (no image in window)")

    # Count the images in the payload
    image_count = sum(1 for msg in messages if isinstance(msg.get("content"), list))
    logger.info(f"📊 Sending: {len(messages)} msgs | {image_count} image(s)")
    payload = {
        "model": "medgemma",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True
    }
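    # The payload follows the OpenAI-style chat-completions schema that
    # llama.cpp's server exposes. The streaming reply arrives as SSE lines,
    # roughly of this shape (illustrative deltas):
    #     data: {"choices":[{"delta":{"content":"Hel"}}]}
    #     data: {"choices":[{"delta":{"content":"lo"}}]}
    #     data: [DONE]
    # which the loop in the try-block below accumulates into `full`.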
    # Use the custom server URL if provided, otherwise the configured default
    target_url = server_url if server_url else SERVER_URL

    try:
        timeout = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT, connect=CONNECT_TIMEOUT)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with SEMAPHORE:
                async with session.post(target_url, json=payload) as resp:
                    if resp.status != 200:
                        logger.error(f"❌ Server error: {resp.status}")
                        yield f"Error: {resp.status}"
                        return

                    full, token_count = "", 0
                    # Stream the response chunks
                    async for line in resp.content:
                        text = line.decode("utf-8", errors="ignore").strip()
                        if not text.startswith('data: '):
                            continue
                        data = text[6:]  # Strip the 'data: ' prefix
                        if data == '[DONE]':
                            break
                        try:
                            obj = json.loads(data)
                            chunk = obj.get('choices', [{}])[0].get('delta', {}).get('content', '')
                            if chunk:
                                full += chunk
                                token_count += 1
                                yield full
                        except Exception:
                            continue  # Skip malformed chunks

                    # An empty response usually means the context window was exceeded
                    if token_count == 0:
                        error_msg = "⚠️ Context size exceeded. Please start a new conversation or reduce history."
                        logger.warning(f"⚠️ Session {session_id[:8]} | Context exceeded")
                        yield error_msg
                    else:
                        logger.info(f"✅ COMPLETED | Session: {session_id[:8]} | Tokens: {token_count} | Active: {len(get_active_sessions())}")
                        # Save the exchange to the session history
                        session_manager.add_message(session_id, "user", prompt, current_image_path)
                        session_manager.add_message(session_id, "assistant", full, None)
    except Exception as e:
        logger.error(f"❌ Exception: {e}")
        yield f"Error: {e}"