| |
|
| | """ |
| | CyberForge ML Inference Module |
| | Backend integration for mlService.js |
| | """ |
| |
|
| | import json |
| | import time |
| | import joblib |
| | import numpy as np |
| | from pathlib import Path |
| | from typing import Dict, List, Any, Optional |
| |
|
class CyberForgeInference:
    """
    ML inference service for CyberForge backend.
    Compatible with mlService.js API contract.

    Artifacts are read from ``models_dir/<model_name>/model.pkl`` plus an
    optional ``models_dir/<model_name>/scaler.pkl``. An optional
    ``models_dir/manifest.json`` lists advertised models under its
    ``"models"`` key.
    """

    def __init__(self, models_dir: str):
        self.models_dir = Path(models_dir)
        # Cache: model_name -> {"model": estimator, "scaler": transformer or None}
        self.loaded_models: Dict[str, Dict[str, Any]] = {}
        self.manifest: Dict[str, Any] = self._load_manifest()

    def _load_manifest(self) -> Dict:
        """Return the parsed ``manifest.json``, or ``{"models": {}}`` if absent."""
        manifest_path = self.models_dir / "manifest.json"
        if manifest_path.exists():
            with open(manifest_path, encoding="utf-8") as f:
                return json.load(f)
        return {"models": {}}

    def _resolve_model_dir(self, model_name: str) -> Optional[Path]:
        """Return the artifact directory for ``model_name``, or None if unsafe.

        ``model_name`` can come straight from an HTTP request, so empty names,
        absolute / drive-anchored paths and any ``..`` component are rejected.
        Otherwise a caller could make ``joblib.load`` unpickle (and thereby
        execute) an arbitrary file outside ``models_dir``.
        """
        name_path = Path(model_name)
        if not name_path.parts or name_path.anchor or ".." in name_path.parts:
            return None
        return self.models_dir / model_name

    def load_model(self, model_name: str) -> bool:
        """Load a model (and its optional scaler) into memory.

        Returns:
            True if the model is loaded (now or by an earlier call); False if
            the name is unsafe or ``model.pkl`` does not exist.
        """
        if model_name in self.loaded_models:
            return True

        model_dir = self._resolve_model_dir(model_name)
        if model_dir is None:
            return False

        model_path = model_dir / "model.pkl"
        scaler_path = model_dir / "scaler.pkl"

        if not model_path.exists():
            return False

        # SECURITY: joblib.load unpickles, which can run arbitrary code;
        # models_dir must only contain trusted artifacts.
        self.loaded_models[model_name] = {
            "model": joblib.load(model_path),
            "scaler": joblib.load(scaler_path) if scaler_path.exists() else None,
        }
        return True

    def predict(self, model_name: str, features: Dict) -> Dict:
        """
        Make a prediction.

        Args:
            model_name: Name of the model to use
            features: Feature dictionary; values are used in the dict's
                iteration order as a single input row.

        Returns:
            Response matching mlService.js contract, or ``{"error": ...}``
            when the model cannot be loaded.
        """
        if not self.load_model(model_name):
            return {"error": f"Model not found: {model_name}"}

        model_data = self.loaded_models[model_name]
        model = model_data["model"]
        scaler = model_data["scaler"]

        # One-row matrix. Column order is the dict's insertion order.
        # NOTE(review): callers must send features in training order; confirm
        # this matches mlService.js.
        X = np.array([list(features.values())])

        if scaler is not None:
            X = scaler.transform(X)

        # perf_counter is monotonic and high-resolution (time.time is neither);
        # the timed span covers all model calls made for this request.
        start_time = time.perf_counter()
        prediction = int(model.predict(X)[0])
        confidence = 0.5  # fallback when the model exposes no probabilities
        if hasattr(model, "predict_proba"):
            proba = model.predict_proba(X)[0]
            confidence = float(np.max(proba))
        inference_time = (time.perf_counter() - start_time) * 1000

        # NOTE(review): risk is derived from the winning class's probability,
        # whatever that class is; a confident "benign" prediction also maps to
        # "critical". Confirm this is the intended mlService.js semantics.
        risk_level = (
            "critical" if confidence >= 0.9 else
            "high" if confidence >= 0.7 else
            "medium" if confidence >= 0.5 else
            "low" if confidence >= 0.3 else "info"
        )

        return {
            "prediction": prediction,
            "confidence": confidence,
            "risk_level": risk_level,
            "model_name": model_name,
            "model_version": "1.0.0",
            "inference_time_ms": inference_time,
        }

    def batch_predict(self, model_name: str, features_list: List[Dict]) -> List[Dict]:
        """Run ``predict`` for each feature dict, preserving input order."""
        return [self.predict(model_name, f) for f in features_list]

    def list_models(self) -> List[str]:
        """List model names advertised in the manifest."""
        return list(self.manifest.get("models", {}).keys())

    def get_model_info(self, model_name: str) -> Dict:
        """Return the manifest entry for ``model_name`` (empty dict if unknown)."""
        return self.manifest.get("models", {}).get(model_name, {})
| |
|
| |
|
| | |
def create_api(models_dir: str):
    """Create FastAPI app for model serving.

    Returns None when FastAPI / pydantic are not importable, so the HTTP
    layer stays optional.

    Routes:
        POST /predict             -> CyberForgeInference.predict result
                                     (404 unknown model, 400 bad features)
        GET  /models              -> {"models": [...]} from the manifest
        GET  /models/{model_name} -> manifest entry (404 if missing)
    """
    try:
        from fastapi import FastAPI, HTTPException
        from pydantic import BaseModel
    except ImportError:
        return None

    app = FastAPI(title="CyberForge ML API", version="1.0.0")
    inference = CyberForgeInference(models_dir)

    class PredictRequest(BaseModel):
        model_name: str
        features: Dict

    # Handlers are plain `def`, not `async def`: model loading and inference
    # are blocking disk/CPU work. FastAPI runs sync handlers in a worker
    # threadpool instead of stalling the event loop.
    @app.post("/predict")
    def predict(request: PredictRequest):
        try:
            result = inference.predict(request.model_name, request.features)
        except ValueError as exc:
            # e.g. wrong feature count or non-numeric values: a client error,
            # not an internal 500.
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        if "error" in result:
            raise HTTPException(status_code=404, detail=result["error"])
        return result

    @app.get("/models")
    def list_models():
        return {"models": inference.list_models()}

    @app.get("/models/{model_name}")
    def get_model_info(model_name: str):
        info = inference.get_model_info(model_name)
        if not info:
            raise HTTPException(status_code=404, detail="Model not found")
        return info

    return app
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # Optional first CLI argument selects the models directory (default: CWD).
    cli_args = sys.argv[1:]
    models_dir = cli_args[0] if cli_args else "."

    service = CyberForgeInference(models_dir)
    available = service.list_models()
    print(f"Available models: {available}")
| |
|