dungeon29 committed on
Commit 47c8ad8 · verified · 1 Parent(s): 3797dd4

Update llm_client.py

Files changed (1)
  1. llm_client.py +85 -85
llm_client.py CHANGED
@@ -1,85 +1,85 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from langchain_huggingface import HuggingFacePipeline
 from langchain.chains import RetrievalQA
 from langchain_core.prompts import PromptTemplate
 import torch

 class LLMClient:
     def __init__(self, vector_store=None):
         """
         Initialize Qwen2.5-3B-Instruct with LangChain
         """
         print("🔷 Loading Qwen2.5-3B-Instruct (LangChain)...")

         model_name = "Qwen/Qwen2.5-1.5B-Instruct"

         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.bfloat16,
             device_map="cpu",
             low_cpu_mem_usage=True,
             trust_remote_code=True
         )

         # Create HF Pipeline
         pipe = pipeline(
             "text-generation",
             model=self.model,
             tokenizer=self.tokenizer,
-            max_new_tokens=256, # Reduced to save memory
-            temperature=0.7,
+            max_new_tokens=512,
+            temperature=0.5,
             top_p=0.9,
             repetition_penalty=1.1,
             do_sample=True
         )

         self.llm = HuggingFacePipeline(pipeline=pipe)
         self.vector_store = vector_store

         print("✅ LLM Client Ready!")

     def analyze(self, text, context_chunks=None):
         """
         Analyze text using LangChain RetrievalQA
         """
         if not self.vector_store:
             return "❌ Vector Store not initialized."

         # Custom Prompt Template
         template = """You are a cybersecurity expert specializing in phishing detection.
 Use the following pieces of context to analyze the input.
 If the input is in Vietnamese, respond in Vietnamese.

 Context:
 {context}

 Input to Analyze:
 {question}

 Analysis:"""

         PROMPT = PromptTemplate(
             template=template,
             input_variables=["context", "question"]
         )

         # Create QA Chain
         qa_chain = RetrievalQA.from_chain_type(
             llm=self.llm,
             chain_type="stuff",
             retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
             chain_type_kwargs={"prompt": PROMPT}
         )

         try:
             print("🤖 Generating response (LangChain)...")
             response = qa_chain.invoke(text)

             # Explicit Garbage Collection
             import gc
             gc.collect()

             return response['result']
         except Exception as e:
             return f"❌ Error: {str(e)}"