dungeon29 committed on
Commit c800b50 · verified · 1 Parent(s): 6c05eaf

Using GGUF model

Files changed (1)
  1. llm_client.py +103 -55
llm_client.py CHANGED
@@ -1,96 +1,144 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from langchain_huggingface import HuggingFacePipeline
  from langchain.chains import RetrievalQA
  from langchain_core.prompts import PromptTemplate
- import torch

- class LLMClient:
-     def __init__(self, vector_store=None):
-         """
-         Initialize Qwen2.5-3B-Instruct with LangChain
-         """
-         print("🔷 Loading Qwen2.5-3B-Instruct (LangChain)...")

-         model_name = "Qwen/Qwen2.5-1.5B-Instruct"

-         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.bfloat16,
-             device_map="cpu",
-             low_cpu_mem_usage=True,
-             trust_remote_code=True
-         )

-         # Create HF Pipeline
-         pipe = pipeline(
-             "text-generation",
-             model=self.model,
-             tokenizer=self.tokenizer,
-             max_new_tokens=256,
-             temperature=0.3,
-             top_p=0.9,
-             repetition_penalty=1.1,
-             do_sample=True,
-             return_full_text=False
-         )

-         self.llm = HuggingFacePipeline(pipeline=pipe)
          self.vector_store = vector_store

-         print("✅ LLM Client Ready!")

      def analyze(self, text, context_chunks=None):
          """
-         Analyze text using LangChain RetrievalQA with MMR
          """
          if not self.vector_store:
              return "❌ Vector Store not initialized."

          # Custom Prompt Template
-         template = """You are a cybersecurity expert. Task: Determine whether the input is 'PHISHING' or 'BENIGN' (Safe). Respond in the following format:
  LABEL: [PHISHING or BENIGN]
- PREDICTION: [Probability of LABEL]
- EXPLANATION: [A brief Vietnamese explanation of a few sentences for reason]

  Context:
  {context}
-

  Input:
  {question}

- Short Analysis:"""

          PROMPT = PromptTemplate(
              template=template,
              input_variables=["context", "question"]
          )

-         retriever_config = {
-             "k": 3,
-             "fetch_k": 10,
-             "lambda_mult": 0.6
-         }
-
          # Create QA Chain
          qa_chain = RetrievalQA.from_chain_type(
              llm=self.llm,
              chain_type="stuff",
-             retriever=self.vector_store.as_retriever(
-                 search_type="mmr",
-                 search_kwargs=retriever_config
-             ),
              chain_type_kwargs={"prompt": PROMPT}
          )

          try:
-             print("🤖 Generating response (LangChain + MMR)...")
              response = qa_chain.invoke(text)
-
-             # Explicit Garbage Collection
-             import gc
-             gc.collect()
-
              return response['result']
          except Exception as e:
-             return f"❌ Error: {str(e)}"
 
+ import os
+ import requests
+ from huggingface_hub import hf_hub_download
+ from langchain.llms.base import LLM
  from langchain.chains import RetrievalQA
  from langchain_core.prompts import PromptTemplate
+ from typing import Any, List, Optional, Mapping

+ # --- Custom LangChain LLM Wrapper for Hybrid Approach ---
+ class HybridLLM(LLM):
+     api_url: str = ""
+     local_llm: Any = None
+
+     @property
+     def _llm_type(self) -> str:
+         return "hybrid_llm"
+
+     def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
+         # 1. Try the Colab API first
+         if self.api_url:
+             try:
+                 print(f"🌐 Calling Colab API: {self.api_url}")
+                 response = requests.post(
+                     f"{self.api_url}/generate",
+                     json={"prompt": prompt, "max_tokens": 512},
+                     timeout=30  # 30s timeout
+                 )
+                 if response.status_code == 200:
+                     return response.json()["response"]
+                 else:
+                     print(f"⚠️ API Error {response.status_code}: {response.text}")
+             except Exception as e:
+                 print(f"⚠️ API Connection Failed: {e}")

+         # 2. Fall back to the local GGUF model
+         if self.local_llm:
+             print("💻 Using Local GGUF Fallback...")
+             # llama-cpp-python accepts a raw prompt string, so pass the prompt through directly
+             output = self.local_llm(
+                 prompt,
+                 max_tokens=512,
+                 stop=["<|im_end|>", "User:", "Input:"],
+                 echo=False
+             )
+             return output['choices'][0]['text']

+         return "❌ Error: No working LLM available (API failed and no local model)."

+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         return {"api_url": self.api_url}

+ class LLMClient:
+     def __init__(self, vector_store=None):
+         """
+         Initialize Hybrid LLM Client
+         """
          self.vector_store = vector_store
+         self.api_url = os.environ.get("COLAB_API_URL", "")  # Read from env var
+         self.local_llm = None

+         # Initialize the local GGUF model (always loaded as a backup, or used when no API URL is set).
+         # With 16 GB of RAM available, a 2B model loads comfortably.
+         try:
+             print("📂 Loading Local Qwen3-VL-2B-Thinking (GGUF)...")
+             from llama_cpp import Llama
+
+             # NOTE: The exact GGUF filename depends on how the repo publishes its quantizations.
+             model_repo = "Qwen/Qwen3-VL-2B-Thinking-GGUF"
+             filename = "Qwen3VL-2B-Thinking-Q4_K_M.gguf"
+
+             model_path = hf_hub_download(
+                 repo_id=model_repo,
+                 filename=filename
+             )
+
+             self.local_llm = Llama(
+                 model_path=model_path,
+                 n_ctx=2048,
+                 n_threads=2,  # Use 2 vCPUs
+                 verbose=False
+             )
+             print("✅ Local GGUF Model Ready!")
+
+         except Exception as e:
+             print(f"⚠️ Could not load local GGUF: {e}")
+
+         # Create the hybrid LangChain wrapper
+         self.llm = HybridLLM(api_url=self.api_url, local_llm=self.local_llm)

      def analyze(self, text, context_chunks=None):
          """
+         Analyze text using LangChain RetrievalQA
          """
          if not self.vector_store:
              return "❌ Vector Store not initialized."

          # Custom Prompt Template
+         template = """<|im_start|>system
+ You are a cybersecurity expert. Task: Determine whether the input is 'PHISHING' or 'BENIGN' (Safe).
+ Respond in the following format:
  LABEL: [PHISHING or BENIGN]
+ EXPLANATION: [A brief Vietnamese explanation]

  Context:
  {context}
+ <|im_end|>
+ <|im_start|>user
  Input:
  {question}

+ Short Analysis:
+ <|im_end|>
+ <|im_start|>assistant
+ """

          PROMPT = PromptTemplate(
              template=template,
              input_variables=["context", "question"]
          )

          # Create QA Chain
          qa_chain = RetrievalQA.from_chain_type(
              llm=self.llm,
              chain_type="stuff",
+             retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
              chain_type_kwargs={"prompt": PROMPT}
          )

          try:
+             print("🤖 Generating response...")
              response = qa_chain.invoke(text)
              return response['result']
          except Exception as e:
+             return f"❌ Error: {str(e)}"
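The new HybridLLM._call POSTs the prompt to {COLAB_API_URL}/generate as JSON {"prompt": ..., "max_tokens": ...} and expects a JSON reply containing a "response" field, falling back to the local GGUF model when the request fails. A minimal sketch of a Colab-side server satisfying that contract, assuming FastAPI/uvicorn and reusing the same GGUF repo and filename the client downloads (the server itself is not part of this commit):

# Hypothetical Colab-side companion to HybridLLM: serves POST /generate
# with {"prompt": ..., "max_tokens": ...} and replies {"response": ...}.
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel

app = FastAPI()

# Same GGUF the client's local fallback tries to load (assumed repo/filename).
model_path = hf_hub_download(
    repo_id="Qwen/Qwen3-VL-2B-Thinking-GGUF",
    filename="Qwen3VL-2B-Thinking-Q4_K_M.gguf",
)
llm = Llama(model_path=model_path, n_ctx=2048, verbose=False)

class GenerateRequest(BaseModel):
    prompt: str
    max_tokens: int = 512

@app.post("/generate")
def generate(req: GenerateRequest) -> dict:
    # Mirror the stop token used by the client's local fallback so both paths behave alike.
    out = llm(req.prompt, max_tokens=req.max_tokens, stop=["<|im_end|>"], echo=False)
    return {"response": out["choices"][0]["text"]}

Expose the server (for example, uvicorn server:app --port 8000 behind an ngrok tunnel) and set COLAB_API_URL to the public URL; if the variable is unset or the endpoint is unreachable, LLMClient answers with the local GGUF model instead.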