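"""
Hybrid LLM client for phishing analysis.

Inference is attempted in priority order:
1. Groq API (qwen/qwen3-32b), if GROQ_API_KEY is set.
2. A remote Colab API, if COLAB_API_URL is set.
3. A local llama-server running Qwen3-0.6B (GGUF) as a last-resort fallback.

Retrieval is handled by a LangChain RetrievalQA chain over the provided vector store.
"""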
import atexit
import os
import stat
import subprocess
import tarfile
import time
from typing import Any, List, Optional, Mapping

import requests
from huggingface_hub import hf_hub_download
from langchain.chains import RetrievalQA
from langchain_core.language_models import LLM
from langchain_core.prompts import PromptTemplate

# --- Helper to Setup llama-server ---
def setup_llama_binaries():
    """
    Download and extract the llama-server binary and its shared libraries
    from the official llama.cpp releases.
    """
    # Prebuilt Linux x64 release (b7312)
    CLI_URL = "https://github.com/ggml-org/llama.cpp/releases/download/b7312/llama-b7312-bin-ubuntu-x64.tar.gz"
    LOCAL_TAR = "llama-cli.tar.gz"
    BIN_DIR = "./llama_bin"
    SERVER_BIN = os.path.join(BIN_DIR, "bin/llama-server")  # Expected server binary location

    if os.path.exists(SERVER_BIN):
        return SERVER_BIN, BIN_DIR

    try:
        print("⬇️ Downloading llama.cpp binaries...")
        response = requests.get(CLI_URL, stream=True)
        if response.status_code == 200:
            with open(LOCAL_TAR, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            print("📦 Extracting binaries...")
            os.makedirs(BIN_DIR, exist_ok=True)
            with tarfile.open(LOCAL_TAR, "r:gz") as tar:
                tar.extractall(path=BIN_DIR)

            # Locate llama-server inside the extracted tree
            found_bin = None
            for root, dirs, files in os.walk(BIN_DIR):
                if "llama-server" in files:
                    found_bin = os.path.join(root, "llama-server")
                    break

            if not found_bin:
                print("❌ Could not find llama-server in extracted files.")
                return None, None

            # Make the binary executable
            st = os.stat(found_bin)
            os.chmod(found_bin, st.st_mode | stat.S_IEXEC)

            print(f"✅ llama-server binary ready at {found_bin}!")
            return found_bin, BIN_DIR
        else:
            print(f"❌ Failed to download binaries: {response.status_code}")
            return None, None
    except Exception as e:
        print(f"❌ Error setting up llama-server: {e}")
        return None, None

# --- Custom LangChain LLM Wrapper for Hybrid Approach ---
class HybridLLM(LLM):
    groq_client: Any = None
    groq_model: str = "qwen/qwen3-32b"
    api_url: str = ""
    local_server_url: str = "http://localhost:8080"

    @property
    def _llm_type(self) -> str:
        return "hybrid_llm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # 1. Try Groq API (highest priority)
        if self.groq_client:
            try:
                print("⚡ Using Groq API...")
                stop_seq = (stop or []) + ["<|im_end|>", "Input:", "Context:"]
                chat_completion = self.groq_client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    model=self.groq_model,
                    temperature=0.3,
                    max_tokens=1024,
                    stop=stop_seq
                )
                return chat_completion.choices[0].message.content
            except Exception as e:
                print(f"⚠️ Groq API Failed: {e}")
                # Continue to next fallback

        # 2. Try Colab API
        if self.api_url:
            try:
                print(f"🌐 Calling Colab API: {self.api_url}")
                response = requests.post(
                    f"{self.api_url}/generate",
                    json={"prompt": prompt, "max_tokens": 512},
                    timeout=30
                )
                if response.status_code == 200:
                    return response.json()["response"]
                print(f"⚠️ API Error {response.status_code}: {response.text}")
            except Exception as e:
                print(f"⚠️ API Connection Failed: {e}")

        # 3. Fallback to the local llama-server
        print("💻 Using Local llama-server Fallback...")
        try:
            # llama.cpp's native /completion endpoint
            payload = {
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.3,
                "stop": (stop or []) + ["<|im_end|>", "Input:", "Context:"]
            }
            response = requests.post(
                f"{self.local_server_url}/completion",
                json=payload,
                timeout=300
            )
            if response.status_code == 200:
                return response.json()["content"]
            print(f"❌ Local Server Error: {response.text}")
        except Exception as e:
            print(f"❌ Local Inference Failed: {e}")

        return "❌ Error: No working LLM available."

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {
            "groq_enabled": self.groq_client is not None,
            "groq_model": self.groq_model,
            "api_url": self.api_url,
            "local_server_url": self.local_server_url
        }

class LLMClient:
    def __init__(self, vector_store=None):
        """
        Initialize the hybrid LLM client and its persistent local server.
        """
        self.vector_store = vector_store
        self.api_url = os.environ.get("COLAB_API_URL", "")
        self.server_process = None
        self.server_port = 8080
        self.groq_client = None
        # Defaults so the local-fallback attributes always exist
        self.server_bin = None
        self.lib_path = None
        self.model_path = None

        # 1. Setup Groq client
        groq_api_key = os.environ.get("GROQ_API_KEY")
        self.groq_model = "qwen/qwen3-32b"
        if groq_api_key:
            try:
                from groq import Groq
                print(f"⚡ Initializing Native Groq Client ({self.groq_model})...")
                self.groq_client = Groq(api_key=groq_api_key)
                print("✅ Groq Client ready.")
            except Exception as e:
                print(f"⚠️ Groq Init Failed: {e}")

        # 2. Setup local fallback (always set up, regardless of Groq availability)
        try:
            # Setup binary
            self.server_bin, self.lib_path = setup_llama_binaries()

            # Download model (Qwen3-0.6B, GGUF)
            print("📥 Loading Local Qwen3-0.6B (GGUF)...")
            model_repo = "Qwen/Qwen3-0.6B-GGUF"
            filename = "Qwen3-0.6B-Q8_0.gguf"
            self.model_path = hf_hub_download(
                repo_id=model_repo,
                filename=filename
            )
            print(f"✅ Model downloaded to: {self.model_path}")

            # Start server
            self.start_local_server()
        except Exception as e:
            print(f"⚠️ Could not setup local fallback: {e}")

        # Create the hybrid LangChain wrapper
        self.llm = HybridLLM(
            groq_client=self.groq_client,
            groq_model=self.groq_model,
            api_url=self.api_url,
            local_server_url=f"http://localhost:{self.server_port}"
        )

    def start_local_server(self):
        """Start llama-server in the background."""
        if not self.server_bin or not self.model_path:
            return

        print("🚀 Starting llama-server...")

        # Make sure the bundled shared libraries are on the loader path
        env = os.environ.copy()
        lib_paths = [os.path.dirname(self.server_bin)]
        lib_subdir = os.path.join(self.lib_path, "lib")
        if os.path.exists(lib_subdir):
            lib_paths.append(lib_subdir)
        env["LD_LIBRARY_PATH"] = ":".join(lib_paths) + ":" + env.get("LD_LIBRARY_PATH", "")

        cmd = [
            self.server_bin,
            "-m", self.model_path,
            "--port", str(self.server_port),
            "-c", "2048",
            "--host", "0.0.0.0"  # Bind to all interfaces for container use
        ]

        # Launch the process
        self.server_process = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,  # Suppress noisy logs
            stderr=subprocess.DEVNULL,
            env=env
        )

        # Register cleanup
        atexit.register(self.stop_server)

        # Poll the health endpoint until the server is ready
        print("⏳ Waiting for server to be ready...")
        for _ in range(20):  # Up to ~20 attempts
            try:
                r = requests.get(f"http://localhost:{self.server_port}/health", timeout=1)
                if r.status_code == 200:
                    print("✅ llama-server is ready!")
                    return
            except requests.exceptions.RequestException:
                pass
            time.sleep(1)
        print("⚠️ Server start timed out (but might still be loading).")

    def stop_server(self):
        """Kill the server process."""
        if self.server_process:
            print("🛑 Stopping llama-server...")
            self.server_process.terminate()
            self.server_process = None

    def analyze(self, text, context_chunks=None):
        """
        Analyze text using a LangChain RetrievalQA chain.
        """
        if not self.vector_store:
            return "❌ Vector Store not initialized."

        # Custom prompt template - strict output format (ChatML style)
        template = """<|im_start|>system
You are CyberGuard - an AI specialized in Phishing Detection.
Task: Analyze the provided URL and HTML snippet to classify the website as 'PHISHING' or 'BENIGN'.
Check specifically for BRAND IMPERSONATION (e.g. Facebook, Google, Banks).
Classification Rules:
- PHISHING: Typosquatting URLs (e.g., paypa1.com), hidden login forms, obfuscated javascript, mismatched branding vs URL.
- BENIGN: Legitimate website, clean code, URL matches the content/brand.
RETURN THE RESULT IN THE EXACT FOLLOWING FORMAT (NO PREAMBLE):
CLASSIFICATION: [PHISHING or BENIGN]
CONFIDENCE SCORE: [0-100]%
EXPLANATION: [Write 3-4 concise sentences explaining the main reason]
<|im_end|>
<|im_start|>user
Context from knowledge base:
{context}
Input to analyze:
{question}
<|im_end|>
<|im_start|>assistant
"""
        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )

        # Create the QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 3, "fetch_k": 10}
            ),
            chain_type_kwargs={"prompt": PROMPT}
        )

        try:
            print("🤖 Generating response...")
            response = qa_chain.invoke(text)
            return response["result"]
        except Exception as e:
            return f"❌ Error: {str(e)}"
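

# --- Usage sketch (illustrative) ---
# A minimal example of how LLMClient might be wired up. The vector store is
# assumed to be built elsewhere in the application; `build_vector_store` below
# is a hypothetical helper and not part of this module.
if __name__ == "__main__":
    # from vector_store import build_vector_store  # hypothetical helper
    # store = build_vector_store("knowledge_base/")
    store = None  # replace with a real LangChain vector store

    client = LLMClient(vector_store=store)
    sample = (
        "URL: http://paypa1-login.example.com\n"
        "HTML: <form action='collect.php'>...</form>"
    )
    print(client.analyze(sample))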