"""
ReAct Agent for Cyber Knowledge Base
This script creates a ReAct agent using LangGraph that can use the CyberKnowledgeBase
search method as a tool to retrieve MITRE ATT&CK techniques.
"""
import os
import sys
import json
from typing import List, Dict, Any, Union, Optional
from pathlib import Path
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langgraph.prebuilt import create_react_agent
from langchain.chat_models import init_chat_model
from langchain_core.language_models.chat_models import BaseChatModel
# Import local modules
from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
# Initialize the knowledge base
def init_knowledge_base(
persist_dir: str = "./cyber_knowledge_base",
) -> CyberKnowledgeBase:
"""Initialize and load the cyber knowledge base"""
kb = CyberKnowledgeBase()
# Try to load existing knowledge base
if kb.load_knowledge_base(persist_dir):
print("[SUCCESS] Loaded existing knowledge base")
return kb
else:
print("[WARNING] Could not load knowledge base, please build it first")
print("Run: python src/scripts/build_cyber_database.py")
sys.exit(1)
def _format_results_as_json(results) -> List[Dict[str, Any]]:
"""Format search results as structured JSON"""
output = []
for doc in results:
technique_info = {
"attack_id": doc.metadata.get("attack_id", "Unknown"),
"name": doc.metadata.get("name", "Unknown"),
"tactics": [
t.strip()
for t in doc.metadata.get("tactics", "").split(",")
if t.strip()
],
"platforms": [
p.strip()
for p in doc.metadata.get("platforms", "").split(",")
if p.strip()
],
"description": (
doc.page_content.split("Description: ")[-1]
if "Description: " in doc.page_content
else doc.page_content
),
"relevance_score": doc.metadata.get(
"relevance_score", None
), # From reranking
}
output.append(technique_info)
return output
def create_agent(llm_client: BaseChatModel, kb: CyberKnowledgeBase):
"""Create a ReAct agent with LangGraph"""
# Define the tools bound to the provided knowledge base
@tool
def search_techniques(
queries: Union[str, List[str]],
top_k: int = 5,
rerank_query: Optional[str] = None,
) -> str:
"""
Search for MITRE ATT&CK techniques using the knowledge base.
This tool searches a vector database containing MITRE ATT&CK technique descriptions,
including their tactics, platforms, and detailed behavioral information. Each technique
in the database has its full description embedded for semantic similarity search.
        Args:
            queries: Single search query string OR a list of query strings.
            top_k: Number of results to return per query (default: 5).
            rerank_query: Optional tag echoed back in the output for transparency.
Returns:
JSON string with results grouped per query. Each group contains:
- query: The original query string
- techniques: List of technique objects (attack_id, name, tactics, platforms, description, relevance_score)
- total_results: Number of techniques in this group
"""
try:
# Convert single query to list for uniform processing
if isinstance(queries, str):
queries = [queries]
            # Run one search per query and keep each result set grouped with its query
results_by_query: List[Dict[str, Any]] = []
for i, q in enumerate(queries, 1):
print(f"[INFO] Query {i}/{len(queries)}: '{q}'")
per_query_results = kb.search(q, top_k=top_k)
techniques = _format_results_as_json(per_query_results)
results_by_query.append(
{
"query": q,
"techniques": techniques,
"total_results": len(techniques),
}
)
# If all queries returned no results
if all(len(group["techniques"]) == 0 for group in results_by_query):
return json.dumps(
{
"results_by_query": results_by_query,
"message": "No techniques found matching the provided queries.",
},
indent=2,
)
return json.dumps(
{
"results_by_query": results_by_query,
"queries_used": queries,
"rerank_query": rerank_query,
},
indent=2,
)
except Exception as e:
return json.dumps(
{
"error": str(e),
"techniques": [],
"message": "Error occurred during search",
},
indent=2,
)
tools = [search_techniques]
# Define the system prompt for the agent
system_prompt = """
You are a cybersecurity analyst assistant that helps answer questions about MITRE ATT&CK techniques.
You have access to a knowledge base of MITRE ATT&CK techniques that you can search.
Use the search_techniques tool to find relevant techniques based on the user's query.
"""
    # Create the ReAct agent with the provided chat model and tools
    agent_runnable = create_react_agent(llm_client, tools, prompt=system_prompt)
return agent_runnable
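# A minimal sketch of calling the returned agent directly (mirrors run_test_queries below):
#   result = agent.invoke({"messages": [HumanMessage(content="How is process injection used?")]})
#   print(result["messages"][-1].content)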
def run_test_queries(agent):
"""Run the agent with some test queries"""
# Test queries
test_queries = [
"What techniques are used for credential dumping?",
"How do attackers use process injection for defense evasion?",
"What are common persistence techniques on Windows systems?",
]
# Run the agent with test queries
for i, query in enumerate(test_queries, 1):
print(f"\n\n===== Test Query {i}: '{query}' =====\n")
# Create the input state
state = {"messages": [HumanMessage(content=query)]}
# Run the agent
result = agent.invoke(state)
# Print all intermediate messages
print("[TRACE] Conversation messages:")
for message in result["messages"]:
if isinstance(message, HumanMessage):
print(f"- [Human] {message.content}")
elif isinstance(message, AIMessage):
agent_name = getattr(message, "name", None) or "agent"
print(f"- [Agent:{agent_name}] {message.content}")
if "function_call" in message.additional_kwargs:
fc = message.additional_kwargs["function_call"]
print(f" [ToolCall] {fc.get('name')}: {fc.get('arguments')}")
elif isinstance(message, ToolMessage):
tool_name = getattr(message, "name", None) or "tool"
print(f"- [Tool:{tool_name}] {message.content}")
def interactive_mode(agent):
"""Run the agent in interactive mode"""
print("\n\n===== Interactive Mode =====")
print("Type 'exit' or 'quit' to end the session\n")
# Keep track of conversation history
messages = []
while True:
# Get user input
user_input = input("\nYou: ")
# Check if user wants to exit
if user_input.lower() in ["exit", "quit"]:
print("Exiting interactive mode...")
break
# Add user message to history
messages.append(HumanMessage(content=user_input))
# Create the input state
state = {"messages": messages.copy()}
# Run the agent
try:
result = agent.invoke(state)
            # Update the conversation history, then print only this turn's new messages
            prior_len = len(messages)
            messages = result["messages"]
            for message in messages[prior_len:]:
if isinstance(message, AIMessage):
print("\n" + "=" * 50)
print(f"\nAgent: {message.content}")
if "function_call" in message.additional_kwargs:
print(
"Function call:",
message.additional_kwargs["function_call"]["name"],
)
print(
"Arguments:",
message.additional_kwargs["function_call"]["arguments"],
)
print("-" * 50)
if isinstance(message, ToolMessage):
print("Tool output:", message.content)
except Exception as e:
print(f"Error: {str(e)}")
def main():
"""Main function to run the agent"""
# Initialize the knowledge base
kb_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"cyber_knowledge_base",
)
kb = init_knowledge_base(kb_path)
# Print KB stats
stats = kb.get_stats()
print(
f"Knowledge base loaded with {stats.get('total_techniques', 'unknown')} techniques"
)
# Initialize the LLM client (using environment variables)
llm_client = init_chat_model("google_genai:gemini-2.0-flash", temperature=0.2)
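    # Assumes Google credentials are available in the environment (typically GOOGLE_API_KEY);
    # swap the model string to use a different provider supported by init_chat_model.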
# Create the agent
agent = create_agent(llm_client, kb)
# Parse command line arguments
import argparse
parser = argparse.ArgumentParser(description="Run the Cyber KB React Agent")
parser.add_argument(
"--interactive", "-i", action="store_true", help="Run in interactive mode"
)
parser.add_argument("--test", "-t", action="store_true", help="Run test queries")
args = parser.parse_args()
# Run in the appropriate mode
if args.interactive:
interactive_mode(agent)
elif args.test:
run_test_queries(agent)
else:
# Default: run interactive mode
interactive_mode(agent)
if __name__ == "__main__":
main()