# PhishingTest / llm_client.py
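"""
Hybrid LLM client for the phishing-detection app.

Resolution order at inference time:
 1. Groq API (qwen/qwen3-32b) when GROQ_API_KEY is set.
 2. A remote Colab inference endpoint when COLAB_API_URL is set.
 3. A local llama-server instance running Qwen3-0.6B (GGUF) as the last resort.

Retrieval-augmented classification is handled by LangChain's RetrievalQA chain
in LLMClient.analyze().
"""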
import os
import requests
import subprocess
import tarfile
import stat
import time
import atexit
from huggingface_hub import hf_hub_download
from langchain_core.language_models import LLM
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from typing import Any, List, Optional, Mapping
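# NOTE: the optional `groq` package is imported lazily inside LLMClient.__init__,
# so this module still imports cleanly when that dependency is not installed.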
# --- Helper to Setup llama-server ---
def setup_llama_binaries():
"""
Download and extract llama-server binary and libs from official releases
"""
# Latest release URL for Linux x64 (b4991 equivalent or newer)
CLI_URL = "https://github.com/ggml-org/llama.cpp/releases/download/b7312/llama-b7312-bin-ubuntu-x64.tar.gz"
LOCAL_TAR = "llama-cli.tar.gz"
BIN_DIR = "./llama_bin"
SERVER_BIN = os.path.join(BIN_DIR, "bin/llama-server") # Look for server binary
if os.path.exists(SERVER_BIN):
return SERVER_BIN, BIN_DIR
try:
print("⬇️ Downloading llama.cpp binaries...")
response = requests.get(CLI_URL, stream=True)
if response.status_code == 200:
with open(LOCAL_TAR, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print("📦 Extracting binaries...")
os.makedirs(BIN_DIR, exist_ok=True)
with tarfile.open(LOCAL_TAR, "r:gz") as tar:
tar.extractall(path=BIN_DIR)
# Locate llama-server
found_bin = None
for root, dirs, files in os.walk(BIN_DIR):
if "llama-server" in files:
found_bin = os.path.join(root, "llama-server")
break
if not found_bin:
print("❌ Could not find llama-server in extracted files.")
return None, None
# Make executable
st = os.stat(found_bin)
os.chmod(found_bin, st.st_mode | stat.S_IEXEC)
print(f"✅ llama-server binary ready at {found_bin}!")
return found_bin, BIN_DIR
else:
print(f"❌ Failed to download binaries: {response.status_code}")
return None, None
except Exception as e:
print(f"❌ Error setting up llama-server: {e}")
return None, None
# --- Custom LangChain LLM Wrapper for Hybrid Approach ---
class HybridLLM(LLM):
    groq_client: Any = None
    groq_model: str = "qwen/qwen3-32b"
    api_url: str = ""
    local_server_url: str = "http://localhost:8080"

    @property
    def _llm_type(self) -> str:
        return "hybrid_llm"
    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # 1. Try Groq API (highest priority)
        if self.groq_client:
            try:
                print("⚡ Using Groq API...")
                stop_seq = (stop or []) + ["<|im_end|>", "Input:", "Context:"]
                chat_completion = self.groq_client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    model=self.groq_model,
                    temperature=0.3,
                    max_tokens=1024,
                    stop=stop_seq
                )
                return chat_completion.choices[0].message.content
            except Exception as e:
                print(f"⚠️ Groq API Failed: {e}")
                # Fall through to the next backend

        # 2. Try the Colab API
        if self.api_url:
            try:
                print(f"🌐 Calling Colab API: {self.api_url}")
                response = requests.post(
                    f"{self.api_url}/generate",
                    json={"prompt": prompt, "max_tokens": 512},
                    timeout=30
                )
                if response.status_code == 200:
                    return response.json()["response"]
                else:
                    print(f"⚠️ API Error {response.status_code}: {response.text}")
            except Exception as e:
                print(f"⚠️ API Connection Failed: {e}")

        # 3. Fall back to the local llama-server
        print("💻 Using Local llama-server Fallback...")
        try:
            # llama-server's native /completion endpoint (expects n_predict, returns "content")
            payload = {
                "prompt": prompt,
                "n_predict": 1024,
                "temperature": 0.3,
                "stop": (stop or []) + ["<|im_end|>", "Input:", "Context:"]
            }
            response = requests.post(
                f"{self.local_server_url}/completion",
                json=payload,
                timeout=300
            )
            if response.status_code == 200:
                return response.json()["content"]
            else:
                return f"❌ Local Server Error: {response.text}"
        except Exception as e:
            return f"❌ Local Inference Failed: {e}"

        return "❌ Error: No working LLM available."
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {
            "groq_enabled": self.groq_client is not None,
            "groq_model": self.groq_model,
            "api_url": self.api_url,
            "local_server_url": self.local_server_url
        }
class LLMClient:
    def __init__(self, vector_store=None):
        """
        Initialize the hybrid LLM client and its persistent local server.
        """
        self.vector_store = vector_store
        self.api_url = os.environ.get("COLAB_API_URL", "")
        self.server_process = None
        self.server_port = 8080
        self.groq_client = None
        # Defensive defaults in case the local fallback setup below fails early
        self.server_bin = None
        self.lib_path = None
        self.model_path = None

        # 1. Setup the Groq client
        groq_api_key = os.environ.get("GROQ_API_KEY")
        self.groq_model = "qwen/qwen3-32b"
        if groq_api_key:
            try:
                from groq import Groq
                print(f"⚡ Initializing Native Groq Client ({self.groq_model})...")
                self.groq_client = Groq(api_key=groq_api_key)
                print("✅ Groq Client ready.")
            except Exception as e:
                print(f"⚠️ Groq Init Failed: {e}")

        # 2. Setup the local fallback (always prepared, even when Groq is available)
        try:
            # Setup the binary
            self.server_bin, self.lib_path = setup_llama_binaries()

            # Download the model (Qwen3-0.6B)
            print("📥 Loading Local Qwen3-0.6B (GGUF)...")
            model_repo = "Qwen/Qwen3-0.6B-GGUF"
            filename = "Qwen3-0.6B-Q8_0.gguf"
            self.model_path = hf_hub_download(
                repo_id=model_repo,
                filename=filename
            )
            print(f"✅ Model downloaded to: {self.model_path}")

            # Start the server
            self.start_local_server()
        except Exception as e:
            print(f"⚠️ Could not setup local fallback: {e}")

        # Create the hybrid LangChain wrapper
        self.llm = HybridLLM(
            groq_client=self.groq_client,
            groq_model=self.groq_model,
            api_url=self.api_url,
            local_server_url=f"http://localhost:{self.server_port}"
        )
    def start_local_server(self):
        """Start llama-server in the background."""
        if not self.server_bin or not self.model_path:
            return

        print("🚀 Starting llama-server...")

        # Make the bundled shared libraries visible to the binary
        env = os.environ.copy()
        lib_paths = [os.path.dirname(self.server_bin)]
        lib_subdir = os.path.join(self.lib_path, "lib")
        if os.path.exists(lib_subdir):
            lib_paths.append(lib_subdir)
        env["LD_LIBRARY_PATH"] = ":".join(lib_paths) + ":" + env.get("LD_LIBRARY_PATH", "")

        cmd = [
            self.server_bin,
            "-m", self.model_path,
            "--port", str(self.server_port),
            "-c", "2048",
            "--host", "0.0.0.0"  # Bind to all interfaces for the container
        ]

        # Launch the process
        self.server_process = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,  # Suppress noisy logs
            stderr=subprocess.DEVNULL,
            env=env
        )

        # Register cleanup
        atexit.register(self.stop_server)

        # Wait for the server to become ready
        print("⏳ Waiting for server to be ready...")
        for _ in range(20):  # Poll for roughly 20 seconds
            try:
                r = requests.get(f"http://localhost:{self.server_port}/health", timeout=1)
                if r.status_code == 200:
                    print("✅ llama-server is ready!")
                    return
            except requests.exceptions.RequestException:
                pass
            time.sleep(1)
        print("⚠️ Server start timed out (but might still be loading).")
    def stop_server(self):
        """Kill the llama-server process."""
        if self.server_process:
            print("🛑 Stopping llama-server...")
            self.server_process.terminate()
            self.server_process = None
    def analyze(self, text, context_chunks=None):
        """
        Analyze text using a LangChain RetrievalQA chain.
        """
        if not self.vector_store:
            return "❌ Vector Store not initialized."

        # Custom prompt template - strict output format
        template = """<|im_start|>system
You are CyberGuard - an AI specialized in Phishing Detection.
Task: Analyze the provided URL and HTML snippet to classify the website as 'PHISHING' or 'BENIGN'.
Check specifically for BRAND IMPERSONATION (e.g. Facebook, Google, Banks).
Classification Rules:
- PHISHING: Typosquatting URLs (e.g., paypa1.com), hidden login forms, obfuscated javascript, mismatched branding vs URL.
- BENIGN: Legitimate website, clean code, URL matches the content/brand.
RETURN THE RESULT IN THE EXACT FOLLOWING FORMAT (NO PREAMBLE):
CLASSIFICATION: [PHISHING or BENIGN]
CONFIDENCE SCORE: [0-100]%
EXPLANATION: [Write 3-4 concise sentences explaining the main reason]
<|im_end|>
<|im_start|>user
Context from knowledge base:
{context}
Input to analyze:
{question}
<|im_end|>
<|im_start|>assistant
"""
        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "question"]
        )

        # Create the QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vector_store.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 3, "fetch_k": 10}
            ),
            chain_type_kwargs={"prompt": PROMPT}
        )

        try:
            print("🤖 Generating response...")
            # RetrievalQA expects its input under the "query" key
            response = qa_chain.invoke({"query": text})
            return response['result']
        except Exception as e:
            return f"❌ Error: {str(e)}"