| | """ |
| | SageMaker Multi-Model Endpoint inference script for GLiNER2. |
| | |
| | This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint. |
| | Models are loaded dynamically based on the TargetModel header in the request. |
| | |
| | Key differences from single-model inference: |
| | - model_fn() receives the full path to the model directory (including model name) |
| | - Models are cached automatically by SageMaker MME |
| | - Multiple models can be loaded in memory simultaneously |
| | - LRU eviction when memory is full |
| | """ |
| |
|
| | import json |
| | import os |
| | import sys |
| | import subprocess |
| |
|
| |
|
def _ensure_gliner2_installed():
    """
    Make sure the ``gliner2`` package is importable, installing it if missing.

    This is a workaround for SageMaker MME where requirements.txt
    might not be installed automatically.

    Returns:
        True when gliner2 is importable (already present or freshly
        installed), False when the pip install failed.
    """
    try:
        import gliner2
    except ImportError:
        pass
    else:
        print(f"[MME] gliner2 version {gliner2.__version__} already installed")
        return True

    print("[MME] gliner2 not found, installing...")
    # Pin gliner2 and constrain transformers to a compatible range.
    pip_command = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--quiet",
        "--no-cache-dir",
        "gliner2==1.0.1",
        "transformers>=4.30.0,<4.46.0",
    ]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError as e:
        print(f"[MME] ERROR: Failed to install gliner2: {e}")
        return False

    print("[MME] ✓ gliner2 installed successfully")
    return True
| |
|
| |
|
| | |
| | # Bootstrap gliner2 at module import time so the `from gliner2 import ...`
| | # statements below can succeed even when requirements.txt was not processed.
| | _ensure_gliner2_installed()
| |
|
| | import torch |
| |
|
| | |
| | # Make the parent of this script's directory importable — presumably for
| | # shared code packaged alongside the inference script (TODO: confirm).
| | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| |
|
| |
|
class DummyModel:
    """Placeholder model for MME container initialization.

    The container-level model entry must never serve inference directly;
    every entry point raises, pointing callers at the TargetModel header.
    """

    _ERROR = "Container model invoked directly. Use TargetModel header."

    def _reject(self):
        # Shared failure path so all entry points raise the same message.
        raise ValueError(self._ERROR)

    def __call__(self, *args, **kwargs):
        self._reject()

    def extract_entities(self, *args, **kwargs):
        self._reject()

    def classify_text(self, *args, **kwargs):
        self._reject()

    def extract_json(self, *args, **kwargs):
        self._reject()
| |
|
| |
|
def model_fn(model_dir):
    """
    Load the GLiNER2 model from the model directory.

    For Multi-Model Endpoints, SageMaker passes the full path to the specific
    model being loaded, e.g., /opt/ml/models/<model_name>/

    Args:
        model_dir: The directory where model artifacts are extracted

    Returns:
        The loaded GLiNER2 model, or a DummyModel placeholder when the
        directory is the MME container-level entry.

    Raises:
        ImportError: If gliner2 is missing and the runtime install failed.
    """
    print(f"[MME] Loading model from: {model_dir}")
    # Best-effort directory listing for debugging; never fail the load on it.
    try:
        print(f"[MME] Contents: {os.listdir(model_dir)}")
    except Exception as e:
        print(f"[MME] Could not list directory contents: {e}")

    # Import gliner2 lazily; if the container skipped requirements.txt,
    # attempt a runtime pip install before giving up.
    try:
        from gliner2 import GLiNER2
    except ImportError as e:
        print(f"[MME] ERROR: gliner2 import failed: {e}")
        print("[MME] Attempting to install gliner2...")
        if _ensure_gliner2_installed():
            from gliner2 import GLiNER2
        else:
            # Sentinel checked below, after the DummyModel short-circuit,
            # so the container entry can still load without gliner2.
            GLiNER2 = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[MME] Using device: {device}")

    # Log GPU details up front so capacity problems are visible in the logs.
    if torch.cuda.is_available():
        print(f"[MME] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[MME] CUDA version: {torch.version.cuda}")
        mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"[MME] GPU memory: {mem_gb:.2f} GB")

    # Optional HuggingFace token for gated/private model downloads.
    hf_token = os.environ.get("HF_TOKEN")

    # Marker file identifies the container-level model entry; it must not
    # serve inference, so return a placeholder that raises on use.
    if os.path.exists(os.path.join(model_dir, "mme_container.txt")):
        print("[MME] Container model detected - returning dummy model")
        return DummyModel()

    if GLiNER2 is None:
        raise ImportError("gliner2 package required but not found")

    # Three loading strategies, in order:
    # 1) config.json present      -> weights were packed into the archive
    # 2) download_at_runtime.txt  -> archive explicitly defers to HF download
    # 3) anything else            -> fall back to downloading from HF as well
    if os.path.exists(os.path.join(model_dir, "config.json")):
        print("[MME] Loading model from extracted artifacts...")
        model = GLiNER2.from_pretrained(model_dir, token=hf_token)
    elif os.path.exists(os.path.join(model_dir, "download_at_runtime.txt")):
        print("[MME] Model not in archive, downloading from HuggingFace...")
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Downloading model: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)
    else:
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Model directory empty, downloading: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)

    print(f"[MME] Moving model to {device}...")
    model = model.to(device)

    # fp16 halves per-model GPU memory, so MME can keep more models resident.
    if torch.cuda.is_available():
        print("[MME] Converting to fp16...")
        model = model.half()

    if torch.cuda.is_available():
        # TF32 can speed up matmul/conv on GPUs that support it.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()
        # Cap this process at 85% of GPU memory, leaving headroom for the
        # other models the MME runtime may load alongside this one.
        torch.cuda.set_per_process_memory_fraction(0.85)
        print("[MME] GPU memory optimizations enabled")

    print(f"[MME] ✓ Model loaded successfully on {device}")
    return model
| |
|
| |
|
def input_fn(request_body, request_content_type):
    """
    Deserialize and prepare the input data for prediction.

    Args:
        request_body: The raw request payload.
        request_content_type: MIME type of the payload.

    Returns:
        The JSON-decoded payload.

    Raises:
        ValueError: For any content type other than application/json.
    """
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)
| |
|
| |
|
# Batch-capable method names tried, in order, for each supported task.
# Falls back to a per-item loop when the model exposes none of them.
_BATCH_METHOD_NAMES = {
    "extract_entities": ("batch_extract_entities", "batch_predict_entities"),
    "classify_text": ("batch_classify_text",),
    "extract_json": ("batch_extract_json",),
}


def predict_fn(input_data, model):
    """
    Run prediction on the input data using the loaded model.

    Args:
        input_data: Dictionary containing:
            - task: One of 'extract_entities', 'classify_text', or 'extract_json'
            - text: Text to process (string) or list of texts (for batch processing)
            - schema: Schema for extraction (format depends on task)
            - threshold: Optional confidence threshold (default: 0.5)
        model: The loaded GLiNER2 model

    Returns:
        Task-specific results (single result or list of results for batch)

    Raises:
        ValueError: If text/schema are missing, the batch list is empty,
            or the task is not one of the supported three.
    """
    # Free cached GPU memory before inference; with MME multiple models
    # share the device.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    text = input_data.get("text")
    task = input_data.get("task", "extract_entities")
    schema = input_data.get("schema")
    threshold = input_data.get("threshold", 0.5)

    is_batch = isinstance(text, list)

    # Check the empty-list case BEFORE the generic truthiness check:
    # `not []` is True, so the previous ordering made this branch
    # unreachable and empty batches got the wrong error message.
    if is_batch and len(text) == 0:
        raise ValueError("'text' list cannot be empty")
    if not text:
        raise ValueError("'text' field is required")
    if not schema:
        raise ValueError("'schema' field is required")

    if task not in _BATCH_METHOD_NAMES:
        raise ValueError(
            f"Unsupported task: {task}. "
            "Must be one of: extract_entities, classify_text, extract_json"
        )

    # Task names match the model's single-item method names.
    single_item = getattr(model, task)

    # inference_mode disables autograd tracking for faster inference.
    with torch.inference_mode():
        if not is_batch:
            return single_item(text, schema, threshold=threshold)

        # Prefer a native batch method when the model provides one.
        for method_name in _BATCH_METHOD_NAMES[task]:
            batch_method = getattr(model, method_name, None)
            if batch_method is not None:
                return batch_method(text, schema, threshold=threshold)

        # Fallback: run the single-item method once per text.
        return [single_item(t, schema, threshold=threshold) for t in text]
| |
|
| |
|
def output_fn(prediction, response_content_type):
    """
    Serialize the prediction output.

    Args:
        prediction: The prediction result (JSON-serializable).
        response_content_type: The desired response content type.

    Returns:
        The prediction encoded as a JSON string.

    Raises:
        ValueError: For any content type other than application/json.
    """
    if response_content_type != "application/json":
        raise ValueError(f"Unsupported response content type: {response_content_type}")
    return json.dumps(prediction)
| |
|