|
|
"""
|
|
|
ReAct Agent for Cyber Knowledge Base
|
|
|
|
|
|
This script creates a ReAct agent using LangGraph that can use the CyberKnowledgeBase
|
|
|
search method as a tool to retrieve MITRE ATT&CK techniques.
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
import json
|
|
|
from typing import List, Dict, Any, Union, Optional
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent.parent))
|
|
|
|
|
|
from langchain_core.tools import tool
|
|
|
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
|
|
|
from langgraph.prebuilt import create_react_agent
|
|
|
from langchain.chat_models import init_chat_model
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
|
|
|
|
|
|
|
from src.knowledge_base.cyber_knowledge_base import CyberKnowledgeBase
|
|
|
|
|
|
|
|
|
|
|
|
def init_knowledge_base(
    persist_dir: str = "./cyber_knowledge_base",
) -> CyberKnowledgeBase:
    """Load the persisted cyber knowledge base or terminate the process.

    Args:
        persist_dir: Directory handed to
            ``CyberKnowledgeBase.load_knowledge_base``.

    Returns:
        The ``CyberKnowledgeBase`` instance once loading reports success.

    Raises:
        SystemExit: With status 1 when ``load_knowledge_base`` returns a
            falsy value; a hint on how to build the database is printed
            first.
    """
    knowledge_base = CyberKnowledgeBase()
    loaded = knowledge_base.load_knowledge_base(persist_dir)

    # Guard clause: nothing useful can run without a loaded store.
    if not loaded:
        print("[WARNING] Could not load knowledge base, please build it first")
        print("Run: python src/scripts/build_cyber_database.py")
        sys.exit(1)

    print("[SUCCESS] Loaded existing knowledge base")
    return knowledge_base
|
|
|
|
|
|
|
|
|
def _split_csv(value) -> List[str]:
    """Turn a comma-separated metadata field into a list of non-empty, trimmed strings."""
    if isinstance(value, str):
        parts = value.split(",")
    elif isinstance(value, (list, tuple, set)):
        # Already a collection: just normalize the entries.
        parts = value
    else:
        # None (key present but unset) or an unexpected type: treat as empty.
        return []
    return [text for text in (str(part).strip() for part in parts) if text]


def _format_results_as_json(results) -> List[Dict[str, Any]]:
    """Convert search hits into JSON-serializable technique dictionaries.

    Args:
        results: Iterable of document-like objects exposing a ``metadata``
            mapping and a ``page_content`` string.

    Returns:
        One dict per document with the keys ``attack_id``, ``name``,
        ``tactics``, ``platforms``, ``description`` and ``relevance_score``.
        ``attack_id`` and ``name`` fall back to ``"Unknown"``;
        ``relevance_score`` falls back to ``None``.
    """
    marker = "Description: "
    output: List[Dict[str, Any]] = []

    for doc in results:
        metadata = doc.metadata or {}
        content = doc.page_content or ""

        # Keep everything after the FIRST marker. Splitting and taking the
        # last piece would truncate a description that itself contains the
        # marker text.
        _, found, tail = content.partition(marker)
        description = tail if found else content

        output.append(
            {
                "attack_id": metadata.get("attack_id", "Unknown"),
                "name": metadata.get("name", "Unknown"),
                "tactics": _split_csv(metadata.get("tactics", "")),
                "platforms": _split_csv(metadata.get("platforms", "")),
                "description": description,
                "relevance_score": metadata.get("relevance_score", None),
            }
        )

    return output
|
|
|
|
|
|
|
|
|
def create_agent(llm_client: BaseChatModel, kb: CyberKnowledgeBase):
    """Create a LangGraph ReAct agent with a knowledge-base search tool.

    Args:
        llm_client: Chat model that drives the agent.
        kb: Knowledge base whose ``search`` method backs the
            ``search_techniques`` tool.

    Returns:
        The runnable returned by ``create_react_agent``.
    """

    # NOTE: the docstring below is what the @tool decorator exposes to the
    # model as the tool description, so it must match the real signature.
    @tool
    def search_techniques(
        queries: Union[str, List[str]],
        top_k: int = 5,
        rerank_query: Optional[str] = None,
    ) -> str:
        """
        Search for MITRE ATT&CK techniques using the knowledge base.

        This tool searches a vector database containing MITRE ATT&CK technique descriptions,
        including their tactics, platforms, and detailed behavioral information. Each technique
        in the database has its full description embedded for semantic similarity search.

        Args:
            queries: Single search query string OR list of query strings.
            top_k: Number of results to return per query (default: 5)
            rerank_query: Optional tag echoed in the output for transparency.

        Returns:
            JSON string with results grouped per query under "results_by_query". Each group contains:
            - query: The original query string
            - techniques: List of technique objects (attack_id, name, tactics, platforms, description, relevance_score)
            - total_results: Number of techniques in this group
            The response also echoes "queries_used" and "rerank_query", carries a "message"
            when nothing was found, and an "error" field if the search failed.
        """
        try:
            # Accept a bare string for convenience; always work on a list.
            query_list = [queries] if isinstance(queries, str) else list(queries)

            results_by_query: List[Dict[str, Any]] = []
            for i, q in enumerate(query_list, 1):
                print(f"[INFO] Query {i}/{len(query_list)}: '{q}'")
                techniques = _format_results_as_json(kb.search(q, top_k=top_k))
                results_by_query.append(
                    {
                        "query": q,
                        "techniques": techniques,
                        "total_results": len(techniques),
                    }
                )

            # Same top-level shape whether or not anything matched; the
            # empty case only adds an explanatory message.
            payload: Dict[str, Any] = {
                "results_by_query": results_by_query,
                "queries_used": query_list,
                "rerank_query": rerank_query,
            }
            if not any(group["techniques"] for group in results_by_query):
                payload["message"] = (
                    "No techniques found matching the provided queries."
                )
            return json.dumps(payload, indent=2)

        except Exception as e:
            # Tool boundary: report the failure to the agent as tool output
            # instead of aborting the whole agent run.
            return json.dumps(
                {
                    "error": str(e),
                    "results_by_query": [],
                    "techniques": [],
                    "message": "Error occurred during search",
                },
                indent=2,
            )

    # Written as explicit pieces so the prompt text does not depend on the
    # indentation of this function body.
    system_prompt = (
        "\n"
        "You are a cybersecurity analyst assistant that helps answer questions about MITRE ATT&CK techniques.\n"
        "\n"
        "You have access to a knowledge base of MITRE ATT&CK techniques that you can search.\n"
        "Use the search_techniques tool to find relevant techniques based on the user's query.\n"
    )

    return create_react_agent(llm_client, [search_techniques], prompt=system_prompt)
|
|
|
|
|
|
|
|
|
def run_test_queries(agent):
    """Invoke the agent on a fixed set of sample questions and print each trace.

    Every question is sent as a fresh single-message conversation, so no
    history carries over between questions.

    Args:
        agent: Object with an ``invoke`` method that accepts
            ``{"messages": [...]}`` and returns a mapping with a
            ``"messages"`` entry.
    """
    sample_questions = (
        "What techniques are used for credential dumping?",
        "How do attackers use process injection for defense evasion?",
        "What are common persistence techniques on Windows systems?",
    )

    for number, question in enumerate(sample_questions, start=1):
        print(f"\n\n===== Test Query {number}: '{question}' =====\n")

        outcome = agent.invoke({"messages": [HumanMessage(content=question)]})

        print("[TRACE] Conversation messages:")
        for entry in outcome["messages"]:
            if isinstance(entry, HumanMessage):
                print(f"- [Human] {entry.content}")
                continue

            if isinstance(entry, AIMessage):
                speaker = getattr(entry, "name", None) or "agent"
                print(f"- [Agent:{speaker}] {entry.content}")
                # Function-call details, when the model attached them here.
                extras = entry.additional_kwargs
                if "function_call" in extras:
                    call = extras["function_call"]
                    print(f"  [ToolCall] {call.get('name')}: {call.get('arguments')}")
                continue

            if isinstance(entry, ToolMessage):
                source = getattr(entry, "name", None) or "tool"
                print(f"- [Tool:{source}] {entry.content}")
|
|
|
|
|
|
|
|
|
def interactive_mode(agent):
    """Run a read-eval-print chat loop against the agent.

    The conversation history is kept across turns and passed to the agent
    on every call. Only the messages produced by the current turn are
    printed. The loop ends on ``exit``/``quit`` (case-insensitive,
    surrounding whitespace ignored), on end-of-input, or on Ctrl-C at the
    prompt.

    Args:
        agent: Object with an ``invoke`` method that accepts
            ``{"messages": [...]}`` and returns a mapping with a
            ``"messages"`` entry.
    """
    print("\n\n===== Interactive Mode =====")
    print("Type 'exit' or 'quit' to end the session\n")

    messages = []

    while True:
        try:
            user_input = input("\nYou: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session cleanly
            # instead of surfacing a traceback.
            print("\nExiting interactive mode...")
            break

        command = user_input.strip()
        if command.lower() in ("exit", "quit"):
            print("Exiting interactive mode...")
            break
        if not command:
            # Nothing to ask; prompt again rather than send an empty message.
            continue

        messages.append(HumanMessage(content=user_input))
        seen = len(messages)

        try:
            result = agent.invoke({"messages": messages.copy()})
        except Exception as e:
            print(f"Error: {e}")
            # Drop the unanswered question so a failed turn does not leave
            # a dangling human message in the history.
            messages.pop()
            continue

        messages = result["messages"]

        # Print only this turn's output. Assumes the agent returns the
        # input history followed by the newly produced messages - verify
        # if the agent is built with a custom state schema.
        for message in messages[seen:]:
            if isinstance(message, AIMessage):
                print("\n" + "=" * 50)
                print(f"\nAgent: {message.content}")
                call = message.additional_kwargs.get("function_call")
                if call:
                    print("Function call:", call.get("name"))
                    print("Arguments:", call.get("arguments"))

                print("-" * 50)
            elif isinstance(message, ToolMessage):
                print("Tool output:", message.content)
|
|
|
|
|
|
|
|
|
def main():
    """Parse CLI flags, load the knowledge base, build the agent and run it.

    Flags:
        --interactive / -i: run the interactive chat loop.
        --test / -t: run the canned test queries.

    Interactive mode is used when no flag is given, and also when both
    flags are given.
    """
    import argparse

    # Kept as a module-level name, as before, in case other code reads it.
    global kb

    # Parse arguments first so `--help` and invalid flags return
    # immediately, without loading the knowledge base or creating the
    # chat model.
    parser = argparse.ArgumentParser(description="Run the Cyber KB React Agent")
    parser.add_argument(
        "--interactive", "-i", action="store_true", help="Run in interactive mode"
    )
    parser.add_argument("--test", "-t", action="store_true", help="Run test queries")
    args = parser.parse_args()

    # The knowledge base directory sits three levels above this file.
    # abspath() keeps that true even if __file__ is a relative path.
    here = os.path.abspath(__file__)
    kb_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(here))),
        "cyber_knowledge_base",
    )
    kb = init_knowledge_base(kb_path)

    stats = kb.get_stats()
    print(
        f"Knowledge base loaded with {stats.get('total_techniques', 'unknown')} techniques"
    )

    llm_client = init_chat_model("google_genai:gemini-2.0-flash", temperature=0.2)
    agent = create_agent(llm_client, kb)

    # --interactive wins over --test; with no flags, interactive is the default.
    if args.test and not args.interactive:
        run_test_queries(agent)
    else:
        interactive_mode(agent)
|
|
|
|
|
|
|
|
|
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|