dungeon29 committed on
Commit e0883f0 · verified · 1 Parent(s): eeb2022

Update llm_client.py

Files changed (1)
  1. llm_client.py +130 -208
llm_client.py CHANGED
@@ -1,225 +1,147 @@
  import os
- import requests
- import subprocess
- import tarfile
- import stat
- from huggingface_hub import hf_hub_download
- from langchain.llms.base import LLM
- from langchain.chains import RetrievalQA
- from langchain_core.prompts import PromptTemplate
- from typing import Any, List, Optional, Mapping

- # --- Helper to Setup llama-cli ---
- def setup_llama_cli():
-     """
-     Download and extract llama-cli binary and libs from official releases
-     """
-     # Latest release URL for Linux x64 (b4991 equivalent or newer)
-     # Using the one found: b7312
-     CLI_URL = "https://github.com/ggml-org/llama.cpp/releases/download/b7312/llama-b7312-bin-ubuntu-x64.tar.gz"
-     LOCAL_TAR = "llama-cli.tar.gz"
-     BIN_DIR = "./llama_bin"  # Extract to a subdirectory
-     CLI_BIN = os.path.join(BIN_DIR, "bin/llama-cli")  # Standard structure usually has bin/
-
-     if os.path.exists(CLI_BIN):
-         return CLI_BIN, BIN_DIR
-
-     try:
-         print("⬇️ Downloading llama-cli binary...")
-         response = requests.get(CLI_URL, stream=True)
-         if response.status_code == 200:
-             with open(LOCAL_TAR, 'wb') as f:
-                 for chunk in response.iter_content(chunk_size=8192):
-                     f.write(chunk)
-
-             print("📦 Extracting llama-cli...")
-             # Create dir
-             os.makedirs(BIN_DIR, exist_ok=True)
-
-             with tarfile.open(LOCAL_TAR, "r:gz") as tar:
-                 tar.extractall(path=BIN_DIR)
-
-             # Locate the binary (it might be in bin/ or root of tar)
-             # We search for it
-             found_bin = None
-             for root, dirs, files in os.walk(BIN_DIR):
-                 if "llama-cli" in files:
-                     found_bin = os.path.join(root, "llama-cli")
-                     break
-
-             if not found_bin:
-                 print("❌ Could not find llama-cli in extracted files.")
-                 return None, None

-             # Make executable
-             st = os.stat(found_bin)
-             os.chmod(found_bin, st.st_mode | stat.S_IEXEC)
-             print(f"✅ llama-cli binary ready at {found_bin}!")
-             return found_bin, BIN_DIR
-         else:
-             print(f"❌ Failed to download binary: {response.status_code}")
-             return None, None
-     except Exception as e:
-         print(f"❌ Error setting up llama-cli: {e}")
-         return None, None

- # --- Custom LangChain LLM Wrapper for Hybrid Approach ---
- class HybridLLM(LLM):
-     api_url: str = ""
-     model_path: str = ""
-     cli_path: str = ""
-     lib_path: str = ""  # Path to folder containing .so files
-
-     @property
-     def _llm_type(self) -> str:
-         return "hybrid_llm"
-
-     def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
-         # 1. Try Colab API first
-         if self.api_url:
-             try:
-                 print(f"🌐 Calling Colab API: {self.api_url}")
-                 response = requests.post(
-                     f"{self.api_url}/generate",
-                     json={"prompt": prompt, "max_tokens": 512},
-                     timeout=30
-                 )
-                 if response.status_code == 200:
-                     return response.json()["response"]
-                 else:
-                     print(f"⚠️ API Error {response.status_code}: {response.text}")
-             except Exception as e:
-                 print(f"⚠️ API Connection Failed: {e}")
-
-         # 2. Fallback to Local llama-cli
-         if self.model_path and self.cli_path and os.path.exists(self.cli_path):
-             print("💻 Using Local llama-cli Fallback...")
              try:
-                 # Construct command
-                 cmd = [
-                     self.cli_path,
-                     "-m", self.model_path,
-                     "-p", prompt,
-                     "-n", "512",
-                     "--temp", "0.7",
-                     "--no-display-prompt",  # Don't echo prompt
-                     "-c", "2048"  # Context size
-                 ]
-
-                 # Setup Environment with LD_LIBRARY_PATH
-                 env = os.environ.copy()
-                 # Add the directory containing the binary (and likely libs) to LD_LIBRARY_PATH
-                 # Also check 'lib' subdir if it exists
-                 lib_paths = [os.path.dirname(self.cli_path)]
-                 lib_subdir = os.path.join(self.lib_path, "lib")
-                 if os.path.exists(lib_subdir):
-                     lib_paths.append(lib_subdir)
-
-                 env["LD_LIBRARY_PATH"] = ":".join(lib_paths) + ":" + env.get("LD_LIBRARY_PATH", "")
-
-                 # Run binary
-                 result = subprocess.run(
-                     cmd,
-                     capture_output=True,
-                     text=True,
-                     encoding='utf-8',
-                     errors='replace',
-                     env=env
-                 )
-
-                 if result.returncode == 0:
-                     return result.stdout.strip()
-                 else:
-                     return f"❌ llama-cli Error: {result.stderr}"
              except Exception as e:
-                 return f"❌ Local Inference Failed: {e}"
-
-         return "Error: No working LLM available (API failed and no local model)."
-
-     @property
-     def _identifying_params(self) -> Mapping[str, Any]:
-         return {"api_url": self.api_url, "model_path": self.model_path}

- class LLMClient:
-     def __init__(self, vector_store=None):
-         """
-         Initialize Hybrid LLM Client with Binary Wrapper
-         """
-         self.vector_store = vector_store
-         self.api_url = os.environ.get("COLAB_API_URL", "")
-         self.model_path = None
-         self.cli_path = None
-         self.lib_path = None

-         # Setup Local Fallback
          try:
-             # 1. Setup Binary
-             self.cli_path, self.lib_path = setup_llama_cli()
-
-             # 2. Download Model (Qwen3-0.6B)
-             print("📂 Loading Local Qwen3-0.6B (GGUF)...")
-             model_repo = "Qwen/Qwen3-0.6B-GGUF"
-             filename = "Qwen3-0.6B-Q8_0.gguf"
-
-             self.model_path = hf_hub_download(
-                 repo_id=model_repo,
-                 filename=filename
-             )
-             print(f"Model downloaded to: {self.model_path}")
-
          except Exception as e:
-             print(f"⚠️ Could not setup local fallback: {e}")

-         # Create Hybrid LangChain Wrapper
-         self.llm = HybridLLM(
-             api_url=self.api_url,
-             model_path=self.model_path,
-             cli_path=self.cli_path,
-             lib_path=self.lib_path
-         )

-     def analyze(self, text, context_chunks=None):
-         """
-         Analyze text using LangChain RetrievalQA
-         """
          if not self.vector_store:
-             return "❌ Vector Store not initialized."
-
-         # Custom Prompt Template
-         template = """<|im_start|>system
- You are a cybersecurity expert. Task: Determine whether the input is 'PHISHING' or 'BENIGN' (Safe).
- Respond in the following format:
- LABEL: [PHISHING or BENIGN]
- EXPLANATION: [A brief Vietnamese explanation]
-
- Context:
- {context}
- <|im_end|>
- <|im_start|>user
- Input:
- {question}
-
- Short Analysis:
- <|im_end|>
- <|im_start|>assistant
- """
-
-         PROMPT = PromptTemplate(
-             template=template,
-             input_variables=["context", "question"]
-         )
-
-         # Create QA Chain
-         qa_chain = RetrievalQA.from_chain_type(
-             llm=self.llm,
-             chain_type="stuff",
-             retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
-             chain_type_kwargs={"prompt": PROMPT}
-         )
-
          try:
-             print("🤖 Generating response...")
-             response = qa_chain.invoke(text)
-             return response['result']
          except Exception as e:
-             return f"Error: {str(e)}"

  import os
+ import glob
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
+ from langchain_qdrant import Qdrant
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from qdrant_client import QdrantClient

+ class RAGEngine:
+     def __init__(self, knowledge_base_dir="./knowledge_base"):
+         self.knowledge_base_dir = knowledge_base_dir
+
+         # Initialize Embeddings
+         self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+         # Qdrant Cloud Configuration
+         # Prioritize Env Vars, fallback to Hardcoded (User provided)
+         self.qdrant_url = os.environ.get("QDRANT_URL") or "https://abd29675-7fb9-4d95-8941-e6130b09bf7f.us-east4-0.gcp.cloud.qdrant.io"
+         self.qdrant_api_key = os.environ.get("QDRANT_API_KEY") or "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.L0aAAAbxRypLfBeGCtFr2xX06iveGb76NrA3BPJQiNM"
+         self.collection_name = "phishing_knowledge"
+
+         if not self.qdrant_url or not self.qdrant_api_key:
+             print("⚠️ QDRANT_URL or QDRANT_API_KEY not set. RAG will not function correctly.")
+             self.client = None  # keep the attribute defined so refresh_knowledge_base() can check it
+             self.vector_store = None
+             return

+         print(f"☁️ Connecting to Qdrant Cloud: {self.qdrant_url}...")
+
+         # Initialize Qdrant Client
+         self.client = QdrantClient(
+             url=self.qdrant_url,
+             api_key=self.qdrant_api_key
+         )
+
+         # Initialize Vector Store Wrapper
+         self.vector_store = Qdrant(
+             client=self.client,
+             collection_name=self.collection_name,
+             embeddings=self.embedding_fn
+         )
+
+         # Check if collection exists/is empty and build if needed
+         try:
+             count = self.client.count(collection_name=self.collection_name).count
+             if count == 0:
+                 self._build_index()
+             else:
+                 print(f"✅ Qdrant Collection '{self.collection_name}' ready with {count} vectors.")
+         except Exception as e:
+             print(f"⚠️ Collection check failed (might not exist): {e}")
+             self._build_index()

+     def _build_index(self):
+         """Load documents and build index"""
+         print("🔄 Building Knowledge Base Index on Qdrant Cloud...")
+
+         documents = self._load_documents()
+         if not documents:
+             print("⚠️ No documents found to index.")
+             return

+         # Split documents
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=500,
+             chunk_overlap=50,
+             separators=["\n\n", "\n", " ", ""]
+         )
+         chunks = text_splitter.split_documents(documents)

+         if chunks:
+             # Add to vector store (Qdrant handles persistence automatically)
              try:
+                 self.vector_store.add_documents(chunks)
+                 print(f"✅ Indexed {len(chunks)} chunks to Qdrant Cloud.")
              except Exception as e:
+                 print(f"❌ Error indexing to Qdrant: {e}")
+         else:
+             print("⚠️ No chunks created.")

+     def _load_documents(self):
+         """Load documents from directory or fallback file"""
+         documents = []

+         # Check for directory or fallback file
+         target_path = self.knowledge_base_dir
+         if not os.path.exists(target_path):
+             if os.path.exists("knowledge_base.txt"):
+                 target_path = "knowledge_base.txt"
+                 print("⚠️ Using fallback 'knowledge_base.txt' in root.")
+             else:
+                 print(f"❌ Knowledge base not found at {target_path}")
+                 return []
+
          try:
+             if os.path.isfile(target_path):
+                 # Load single file
+                 if target_path.endswith(".pdf"):
+                     loader = PyPDFLoader(target_path)
+                 else:
+                     loader = TextLoader(target_path, encoding="utf-8")
+                 documents.extend(loader.load())
+             else:
+                 # Load directory
+                 loaders = [
+                     DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
+                     DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
+                     DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
+                 ]
+
+                 for loader in loaders:
+                     try:
+                         docs = loader.load()
+                         documents.extend(docs)
+                     except Exception as e:
+                         print(f"⚠️ Error loading with {loader}: {e}")
+
          except Exception as e:
+             print(f"❌ Error loading documents: {e}")
+
+         return documents

+     def refresh_knowledge_base(self):
+         """Force rebuild of the index"""
+         print("♻️ Refreshing Knowledge Base...")
+         if self.client:
+             try:
+                 self.client.delete_collection(self.collection_name)
+                 self._build_index()
+                 return "✅ Knowledge Base Refreshed on Cloud!"
+             except Exception as e:
+                 return f"❌ Error refreshing: {e}"
+         return "❌ Qdrant Client not initialized."

+     def retrieve(self, query, n_results=3):
+         """Retrieve relevant context"""
          if not self.vector_store:
+             return []
+
+         # Search
          try:
+             results = self.vector_store.similarity_search(query, k=n_results)
+             if results:
+                 return [doc.page_content for doc in results]
          except Exception as e:
+             print(f"⚠️ Retrieval Error: {e}")
+
+         return []
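
A minimal usage sketch of the new RAGEngine, under two assumptions not stated in the commit: the module is importable as llm_client, and QDRANT_URL / QDRANT_API_KEY are exported so the hardcoded fallbacks above never apply.

import os
from llm_client import RAGEngine  # assumed import path for the module above

# Credentials are read from the environment first; the hardcoded fallbacks
# in __init__ are only used when these variables are missing.
assert os.environ.get("QDRANT_URL") and os.environ.get("QDRANT_API_KEY")

# Connects to Qdrant Cloud and indexes ./knowledge_base into the
# "phishing_knowledge" collection if it is empty or missing.
engine = RAGEngine(knowledge_base_dir="./knowledge_base")

# Fetch up to 3 relevant knowledge-base chunks for a suspicious message.
for chunk in engine.retrieve("Your account is locked. Verify at http://example.test", n_results=3):
    print(chunk)

# Drop and rebuild the collection after the knowledge-base files change.
print(engine.refresh_knowledge_base())

Note that retrieve() returns plain page_content strings: callers that previously relied on LLMClient.analyze() must now assemble the prompt and invoke an LLM themselves.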