Spaces:
Sleeping
Sleeping
Muhammed Machrouh
commited on
Commit
•
4531c67
1
Parent(s):
1361fdd
Initial files
Browse files- .DS_Store +0 -0
- agents/investigator.py +113 -0
- agents/router.py +111 -0
- app.py +53 -0
- template.env +4 -0
- tools/coder_tool.py +30 -0
- tools/cve_avd_tool.py +143 -0
- tools/cve_search_tool.py +38 -0
- tools/elastic_tool.py +57 -0
- tools/log_tool.py +1 -0
- tools/misp_tool.py +63 -0
- tools/mitre_tool.py +49 -0
- tools/scraper_tools.py +31 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
agents/investigator.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.llms import Ollama
|
2 |
+
from langchain_community.chat_models import ChatOllama
|
3 |
+
|
4 |
+
from langchain import hub
|
5 |
+
|
6 |
+
from agentops.langchain_callback_handler import LangchainCallbackHandler as AgentOpsLangchainCallbackHandler
|
7 |
+
|
8 |
+
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
|
9 |
+
|
10 |
+
from tools.cve_avd_tool import CVESearchTool
|
11 |
+
from tools.misp_tool import MispTool
|
12 |
+
from tools.coder_tool import CoderTool
|
13 |
+
from tools.mitre_tool import MitreTool
|
14 |
+
|
15 |
+
from langchain.agents import initialize_agent, AgentType, load_tools
|
16 |
+
from langchain.evaluation import load_evaluator
|
17 |
+
|
18 |
+
|
19 |
+
from dotenv import load_dotenv
|
20 |
+
import os
|
21 |
+
import re
|
22 |
+
|
23 |
+
load_dotenv(override=True)
|
24 |
+
|
25 |
+
|
26 |
+
llm = Ollama(model="openhermes", base_url=os.getenv('OLLAMA_HOST'), temperature=0.3, num_predict=8192, num_ctx=8192)
|
27 |
+
wrn = Ollama(model="openfc", base_url=os.getenv('OLLAMA_HOST'))
|
28 |
+
llama3 = Ollama(model="llama3", base_url=os.getenv('OLLAMA_HOST'), temperature=0.3)
|
29 |
+
|
30 |
+
command_r = Ollama(model="command-r", base_url=os.getenv('OLLAMA_HOST'), temperature=0.1, num_ctx=8192)
|
31 |
+
hermes_llama3 = Ollama(model="adrienbrault/nous-hermes2pro-llama3-8b:q4_K_M", base_url=os.getenv('OLLAMA_HOST'), temperature=0.3, num_ctx=32768)
|
32 |
+
yarn_mistral_128k = Ollama(model="yarn-mistral-modified", base_url=os.getenv('OLLAMA_HOST'), temperature=0.1, num_ctx=65536, system="""""")
|
33 |
+
|
34 |
+
chat_llm = ChatOllama(model="openhermes", base_url=os.getenv('OLLAMA_HOST'), num_predict=-1)
|
35 |
+
|
36 |
+
|
37 |
+
cve_search_tool = CVESearchTool().cvesearch
|
38 |
+
fetch_cve_tool = CVESearchTool().fetchcve
|
39 |
+
misp_search_tool = MispTool().search
|
40 |
+
misp_search_by_date_tool = MispTool().search_by_date
|
41 |
+
misp_search_by_event_id_tool = MispTool().search_by_event_id
|
42 |
+
coder_tool = CoderTool().code_generation_tool
|
43 |
+
|
44 |
+
get_technique_by_id = MitreTool().get_technique_by_id
|
45 |
+
get_technique_by_name = MitreTool().get_technique_by_name
|
46 |
+
get_malware_by_name = MitreTool().get_malware_by_name
|
47 |
+
get_tactic_by_keyword = MitreTool().get_tactic_by_keyword
|
48 |
+
|
49 |
+
tools = [cve_search_tool, fetch_cve_tool, misp_search_tool, misp_search_by_date_tool, misp_search_by_event_id_tool,
|
50 |
+
coder_tool, get_technique_by_id, get_technique_by_name, get_malware_by_name, get_tactic_by_keyword]
|
51 |
+
|
52 |
+
# conversational agent memory
|
53 |
+
memory = ConversationBufferWindowMemory(
|
54 |
+
memory_key='chat_history',
|
55 |
+
k=4,
|
56 |
+
return_messages=True
|
57 |
+
)
|
58 |
+
|
59 |
+
agentops_handler = AgentOpsLangchainCallbackHandler(api_key=os.getenv("AGENTOPS_API_KEY"), tags=['Langchain Example'])
|
60 |
+
|
61 |
+
#Error handling
|
62 |
+
def _handle_error(error) -> str:
|
63 |
+
|
64 |
+
pattern = r'```(?!json)(.*?)```'
|
65 |
+
match = re.search(pattern, str(error), re.DOTALL)
|
66 |
+
if match:
|
67 |
+
return "The answer contained a code blob which caused the parsing to fail, i recovered the code blob. Just use it to answer the user question: " + match.group(1)
|
68 |
+
else:
|
69 |
+
return llm.invoke(f"""Try to summarize and explain the following error into 1 short and consice sentence and give a small indication to correct the error: {error} """)
|
70 |
+
|
71 |
+
|
72 |
+
prompt = hub.pull("hwchase17/react-chat-json")
|
73 |
+
# create our agent
|
74 |
+
conversational_agent = initialize_agent(
|
75 |
+
# agent="chat-conversational-react-description",
|
76 |
+
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
77 |
+
tools=tools,
|
78 |
+
prompt=prompt,
|
79 |
+
llm=llm,
|
80 |
+
verbose=True,
|
81 |
+
max_iterations=5,
|
82 |
+
memory=memory,
|
83 |
+
early_stopping_method='generate',
|
84 |
+
# callbacks=[agentops_handler],
|
85 |
+
handle_parsing_errors=_handle_error,
|
86 |
+
return_intermediate_steps=False,
|
87 |
+
max_execution_time=40,
|
88 |
+
)
|
89 |
+
|
90 |
+
evaluator = load_evaluator("trajectory", llm=chat_llm)
|
91 |
+
|
92 |
+
|
93 |
+
# conversational_agent.agent.llm_chain.prompt.messages[0].prompt.template = """
|
94 |
+
# 'Respond to the human as helpfully and accurately as possible.
|
95 |
+
# You should use the tools available to you to help answer the question.
|
96 |
+
# Your final answer should be technical, well explained, and accurate.
|
97 |
+
# You have access to the following tools:\n\n\n\nUse a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\nValid "action" values: "Final Answer" or \n\nProvide only ONE action per $JSON_BLOB, as shown:\n\n```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\nFollow this format:\n\nQuestion: input question to answer\nThought: consider previous and subsequent steps\nAction:\n```\n$JSON_BLOB\n```\nObservation: action result\n... (repeat Thought/Action/Observation N times)\nThought: I know what to respond\nAction:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n```\n\nBegin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.\nThought:'
|
98 |
+
# """
|
99 |
+
|
100 |
+
template = conversational_agent.agent.llm_chain.prompt.messages[0].prompt.template
|
101 |
+
|
102 |
+
conversational_agent.agent.llm_chain.prompt.messages[0].prompt.template = """You are a cyber security analyst, you role is to respond to the human queries in a technical way while providing detailed explanations when providing final answer.""" + template
|
103 |
+
|
104 |
+
def invoke(input_text):
|
105 |
+
results = conversational_agent({"input":input_text})
|
106 |
+
# evaluation_result = evaluator.evaluate_agent_trajectory(
|
107 |
+
# prediction=results["output"],
|
108 |
+
# input=results["input"],
|
109 |
+
# agent_trajectory=results["intermediate_steps"],
|
110 |
+
# )
|
111 |
+
|
112 |
+
# print(evaluation_result)
|
113 |
+
return results['output']
|
agents/router.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.llms import Ollama
|
2 |
+
|
3 |
+
from langchain import hub
|
4 |
+
|
5 |
+
from agentops.langchain_callback_handler import LangchainCallbackHandler as AgentOpsLangchainCallbackHandler
|
6 |
+
|
7 |
+
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
|
8 |
+
|
9 |
+
from langchain.agents import initialize_agent, AgentType, load_tools
|
10 |
+
|
11 |
+
from langchain.tools import StructuredTool, Tool, ShellTool
|
12 |
+
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
import os
|
15 |
+
import re, json
|
16 |
+
|
17 |
+
from langchain.agents import create_json_agent
|
18 |
+
from langchain.agents.agent_toolkits import JsonToolkit
|
19 |
+
from langchain.tools.json.tool import JsonSpec
|
20 |
+
|
21 |
+
from .investigator import *
|
22 |
+
from .investigator import invoke as investigator_invoke
|
23 |
+
|
24 |
+
load_dotenv(override=True)
|
25 |
+
|
26 |
+
|
27 |
+
llm = Ollama(model="openhermes", base_url=os.getenv('OLLAMA_HOST'), temperature=0.3, num_predict=8192, num_ctx=8192)
|
28 |
+
wrn = Ollama(model="wrn", base_url=os.getenv('OLLAMA_HOST'))
|
29 |
+
wrn = Ollama(model="openfc", base_url=os.getenv('OLLAMA_HOST'))
|
30 |
+
|
31 |
+
# def get_json_agent(json_path: str):
|
32 |
+
# with open(json_path) as f:
|
33 |
+
# data = json.load(f)
|
34 |
+
# json_spec = JsonSpec(dict_=data, max_value_length=4000)
|
35 |
+
# json_toolkit = JsonToolkit(spec=json_spec)
|
36 |
+
|
37 |
+
# json_agent = create_json_agent(
|
38 |
+
# llm=llm,
|
39 |
+
# toolkit=json_toolkit,
|
40 |
+
# verbose=True
|
41 |
+
# )
|
42 |
+
# return json_agent
|
43 |
+
|
44 |
+
# def investigate_agent():
|
45 |
+
# """
|
46 |
+
# This function will help you execute a query to find information about a security event. Just provide the request and get the response.
|
47 |
+
# Parameters:
|
48 |
+
# - request: The request to search for
|
49 |
+
# Returns:
|
50 |
+
# - The response of the search
|
51 |
+
# """
|
52 |
+
|
53 |
+
# def investigate(request: str):
|
54 |
+
# json_agent = get_json_agent("./inventory_prices_dict.json")
|
55 |
+
# result = json_agent.run(
|
56 |
+
# f"""get the price of {inventory_item} from the json file.
|
57 |
+
# Find the closest match to the item you're looking for in that json, e.g.
|
58 |
+
# if you're looking for "mahogany oak table" and that is not in the json, use "table".
|
59 |
+
# Be mindful of the format of the json - there is no list that you can access via [0], so don't try to do that
|
60 |
+
# """)
|
61 |
+
# return result
|
62 |
+
|
63 |
+
investigate_tool = Tool(name="Investigate Tool",
|
64 |
+
description="This tool will help you execute a query to find information about a security event.(Can be a MISP event, CVE, MITRE attack or technique, malware...) Just provide the request and get the response.",
|
65 |
+
func=investigator_invoke)
|
66 |
+
|
67 |
+
shell_tool = ShellTool()
|
68 |
+
tools = [investigate_tool, shell_tool]
|
69 |
+
|
70 |
+
|
71 |
+
memory = ConversationBufferWindowMemory(
|
72 |
+
memory_key='chat_history',
|
73 |
+
k=4,
|
74 |
+
return_messages=True
|
75 |
+
)
|
76 |
+
|
77 |
+
agent = initialize_agent(
|
78 |
+
agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
|
79 |
+
tools=tools,
|
80 |
+
# prompt=prompt,
|
81 |
+
llm=llm,
|
82 |
+
verbose=True,
|
83 |
+
max_iterations=5,
|
84 |
+
memory=memory,
|
85 |
+
early_stopping_method='generate',
|
86 |
+
# return_intermediate_steps=True,
|
87 |
+
handle_parsing_errors=True,
|
88 |
+
max_execution_time=40,
|
89 |
+
)
|
90 |
+
|
91 |
+
template = agent.agent.llm_chain.prompt.messages[0].prompt.template
|
92 |
+
|
93 |
+
# agent.agent.llm_chain.prompt.messages[0].prompt.template = """You are a cyber security analyst called Sonic Cyber Assistant, you were built by a team of engineers at UM6P and DGSSI. you role is to respond to the human queries in a technical way while providing detailed explanations when providing final answer.
|
94 |
+
# your role is to respond to human queries in a technical manner while providing detailed explanations in your final answers. You have a set of tools at your disposal to assist in answering questions. Always delegate investigative tasks to the Investigate Tool, which will perform the investigation and provide results for you to use in your responses. If the Investigate Tool's response contains important information, include it in your answer. If it does not, use the response to formulate your answer. For executing commands, use the Shell Tool and provide the output to the user. Preserve any code blocks and links in your responses as they may contain important information. If a question is unclear, ask the user for clarification. When faced with multiple questions, answer each one separately and sequentially. Never answer questions that are not related to cybersecurity.
|
95 |
+
# """
|
96 |
+
agent.agent.llm_chain.prompt.messages[0].prompt.template = """You are a cyber security analyst called Sonic Cyber Assistant, you were built by a team of engineers at UM6P and DGSSI. you role is to respond to the human queries in a technical way while providing detailed explanations when providing final answer.
|
97 |
+
You are provided with a set of tools to help you answer the questions. Use the tools to help you answer the questions.
|
98 |
+
Always delegate the investigation to the Investigate Tool. The Investigate Tool will perform the investigation and provide the results, which you will use to answer the user's question. If the Investigate Tool's response contains some important information, answer the user's question while providing the information. If the Investigate Tool's response does not contain important information, use the Investigate Tool's response to answer the user's question.
|
99 |
+
If the user asked you to execute a command, use the Shell Tool to execute the command and provide the output to the user.
|
100 |
+
Also try to preserve any code blocks in the response as well as links, as they may contain important information.
|
101 |
+
If the question is not clear, ask the user to clarify the question.
|
102 |
+
One important thing to remember is that if the question is composed of multiple questions, answer each question separately in a sequential manner.
|
103 |
+
NEVER ANSWER QUESTIONS THAT ARE NOT RELATED TO CYBERSECURITY.
|
104 |
+
"""
|
105 |
+
# print(agent.agent.llm_chain.prompt.messages[0].prompt.template)
|
106 |
+
|
107 |
+
def invoke(input_text):
|
108 |
+
return agent({"input":input_text})
|
109 |
+
|
110 |
+
def generate_title(input_text):
|
111 |
+
return llm.invoke(f"Generate a title for the following question: {input_text}, the title should be short and concise.")
|
app.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from agents import investigator, router
|
3 |
+
# from agents_openai_fc import investigator
|
4 |
+
|
5 |
+
st.title('Cyber Hunter!')
|
6 |
+
st.caption("🚀 A streamlit chatbot powered by OpenAI LLM")
|
7 |
+
|
8 |
+
if "messages" not in st.session_state:
|
9 |
+
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
10 |
+
|
11 |
+
for msg in st.session_state.messages:
|
12 |
+
st.chat_message(msg["role"]).write(msg["content"])
|
13 |
+
|
14 |
+
if prompt := st.chat_input():
|
15 |
+
|
16 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
17 |
+
st.chat_message("user").write(prompt)
|
18 |
+
response = router.invoke(prompt)
|
19 |
+
msg = response["output"]
|
20 |
+
print("Title : ", router.generate_title(prompt))
|
21 |
+
|
22 |
+
# print(response['intermediate_steps'])
|
23 |
+
|
24 |
+
# If there's an existing assistant message, update it with the new response
|
25 |
+
if st.session_state.messages[-1]["role"] == "assistant":
|
26 |
+
st.session_state.messages[-1]["content"] = msg
|
27 |
+
else:
|
28 |
+
st.session_state.messages.append({"role": "assistant", "content": msg})
|
29 |
+
|
30 |
+
st.chat_message("assistant").write(msg)
|
31 |
+
|
32 |
+
# Add a button to regenerate response
|
33 |
+
if st.button("Regenerate Response"):
|
34 |
+
st.session_state.messages.pop() # Remove the latest assistant message
|
35 |
+
prompt = st.session_state.messages[-1]["content"] # Retrieve user's last input
|
36 |
+
with st.spinner("Searching..."):
|
37 |
+
response = router.invoke(prompt)
|
38 |
+
msg = response["output"]
|
39 |
+
|
40 |
+
# If there's an existing assistant message, update it with the new response
|
41 |
+
if st.session_state.messages[-1]["role"] == "assistant":
|
42 |
+
st.session_state.messages[-1]["content"] = msg
|
43 |
+
else:
|
44 |
+
st.session_state.messages.append({"role": "assistant", "content": msg})
|
45 |
+
|
46 |
+
st.chat_message("assistant").write(msg)
|
47 |
+
|
48 |
+
if st.button("Clear Chat"):
|
49 |
+
st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]
|
50 |
+
router.memory.clear()
|
51 |
+
for msg in st.session_state.messages:
|
52 |
+
st.chat_message(msg["role"]).write(msg["content"])
|
53 |
+
st.success("Chat cleared!")
|
template.env
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OLLAMA_HOST=""
|
2 |
+
MISP_URL=""
|
3 |
+
MISP_KEY=""
|
4 |
+
ENV="dev"
|
tools/coder_tool.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.tools import tool
|
2 |
+
from langchain_community.llms import Ollama
|
3 |
+
|
4 |
+
import os
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
load_dotenv(override=True)
|
7 |
+
|
8 |
+
wrn = Ollama(model="wrn", base_url=os.getenv('OLLAMA_HOST'), num_predict=512, temperature=0.2,
|
9 |
+
system="""
|
10 |
+
You are a coder and you are trying to generate a code snippet based on a given prompt.
|
11 |
+
The code snippet should be in the programming language that's asked for.
|
12 |
+
Don't Wrap the function in a markdown code block. Return it as a text.
|
13 |
+
""")
|
14 |
+
|
15 |
+
|
16 |
+
class CoderTool():
|
17 |
+
@tool("Code Generation Tool")
|
18 |
+
def code_generation_tool(instruction: str, language: str = "python"):
|
19 |
+
"""The code generation tool is a tool that can generate code snippets based on a given instruction.
|
20 |
+
It uses a language model to generate code snippets that are relevant to the given instruction.
|
21 |
+
Parameters:
|
22 |
+
- instruction: The instruction for which the code snippet should be generated.
|
23 |
+
- language: The programming language in which the code snippet should be generated. Default is python.
|
24 |
+
Returns:
|
25 |
+
- A code snippet generated based on the given instruction.
|
26 |
+
"""
|
27 |
+
|
28 |
+
response = wrn.invoke(instruction)
|
29 |
+
response = response.replace("```", "")
|
30 |
+
return f"'{response}'"
|
tools/cve_avd_tool.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from langchain.tools import tool
|
3 |
+
import json
|
4 |
+
|
5 |
+
from datetime import datetime, timedelta
|
6 |
+
|
7 |
+
def get_current_formatted_date():
|
8 |
+
# Get the current date and time
|
9 |
+
now = datetime.now()
|
10 |
+
|
11 |
+
# Format the date and time in the desired format
|
12 |
+
formatted_date = now.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3]
|
13 |
+
return formatted_date
|
14 |
+
|
15 |
+
def get_yesterday_formatted_date():
|
16 |
+
# Get the current date and time
|
17 |
+
now = datetime.now()
|
18 |
+
# Calculate yesterday's date and time
|
19 |
+
yesterday = now - timedelta(days=2)
|
20 |
+
# Format the date and time in the desired format
|
21 |
+
formatted_date = yesterday.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3]
|
22 |
+
return formatted_date
|
23 |
+
|
24 |
+
def format_cve(cve_response):
|
25 |
+
"""
|
26 |
+
Takes a dictionary representing the API response and formats each CVE entry into a structured prompt for LLM.
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
- cve_response: A dictionary representing the API response containing CVE information.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
- A list of strings, each a formatted prompt for LLM based on the CVE entries in the response.
|
33 |
+
"""
|
34 |
+
formatted_prompts = "Use these CVE Search Results to answer user question:\n\n"
|
35 |
+
|
36 |
+
if len(cve_response['vulnerabilities']) == 0:
|
37 |
+
|
38 |
+
formatted_prompts = "No CVEs found matching the search criteria."
|
39 |
+
return formatted_prompts
|
40 |
+
|
41 |
+
|
42 |
+
for vulnerability in cve_response['vulnerabilities']:
|
43 |
+
cve = vulnerability.get('cve', {})
|
44 |
+
# prompt = f"Explain {cve.get('id', 'N/A')} in simple terms:\n\n"
|
45 |
+
prompt = f"- CVE ID: {cve.get('id', 'N/A')}\n"
|
46 |
+
prompt += f"- Status: {cve.get('vulnStatus', 'Unknown')}\n"
|
47 |
+
|
48 |
+
descriptions = cve.get('descriptions', [])
|
49 |
+
description_text = descriptions[0].get('value', 'No description available.') if descriptions else "No description available."
|
50 |
+
prompt += f"- Description: {description_text}\n"
|
51 |
+
|
52 |
+
if 'metrics' in cve and 'cvssMetricV2' in cve['metrics']:
|
53 |
+
cvss_metrics = cve['metrics']['cvssMetricV2'][0]
|
54 |
+
prompt += f"- CVSS Score: {cvss_metrics.get('cvssData', {}).get('baseScore', 'Not available')} ({cvss_metrics.get('baseSeverity', 'Unknown')})\n"
|
55 |
+
else:
|
56 |
+
prompt += "- CVSS Score: Not available\n"
|
57 |
+
|
58 |
+
configurations = cve.get('configurations', {})
|
59 |
+
for conf in configurations:
|
60 |
+
nodes = conf.get('nodes', [])
|
61 |
+
affected_configs = []
|
62 |
+
for node in nodes:
|
63 |
+
for cpe_match in node.get('cpeMatch', []):
|
64 |
+
if cpe_match.get('vulnerable', False):
|
65 |
+
affected_configs.append(cpe_match.get('criteria', 'Not specified'))
|
66 |
+
prompt += f"- Affected Configurations: {', '.join(affected_configs) if affected_configs else 'Not specified'}\n"
|
67 |
+
|
68 |
+
references = cve.get('references', [])
|
69 |
+
ref_urls = ', '.join([ref.get('url', 'No URL') for ref in references])
|
70 |
+
prompt += f"- References: {ref_urls if references else 'No references available.'}\n"
|
71 |
+
|
72 |
+
|
73 |
+
formatted_prompts += prompt+"\n\n"
|
74 |
+
print(formatted_prompts)
|
75 |
+
|
76 |
+
# formatted_prompts += "\nSummarize the vulnerability, its impact, and any known mitigation strategies."
|
77 |
+
|
78 |
+
return formatted_prompts
|
79 |
+
|
80 |
+
class CVESearchTool():
|
81 |
+
@tool("CVE search Tool")
|
82 |
+
def cvesearch(keyword: str, date: str = None):
|
83 |
+
"""
|
84 |
+
Searches for CVEs based on a keyword or phrase and returns the results in JSON format.
|
85 |
+
|
86 |
+
Parameters:
|
87 |
+
- keyword (str): A word or phrase to search for in the CVE descriptions.
|
88 |
+
- date (str): An optional date to include in the search query.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
- JSON: A list of CVEs matching the keyword search.
|
92 |
+
"""
|
93 |
+
|
94 |
+
if date:
|
95 |
+
keyword = f"{keyword} {date}"
|
96 |
+
# Encode the spaces in the keyword(s) as "%20" for the URL
|
97 |
+
keyword_encoded = keyword.replace(" ", "%20")
|
98 |
+
# keyword_encoded = keyword_encoded.join(" 2023")
|
99 |
+
|
100 |
+
# Construct the URL for the API request
|
101 |
+
url = f"https://services.nvd.nist.gov/rest/json/cves/2.0?keywordSearch={keyword_encoded}&resultsPerPage=5"
|
102 |
+
|
103 |
+
try:
|
104 |
+
# Send the request to the NVD API
|
105 |
+
response = requests.get(url)
|
106 |
+
# Check if the request was successful
|
107 |
+
if response.status_code == 200:
|
108 |
+
# Return the JSON response
|
109 |
+
return format_cve(response.json())
|
110 |
+
else:
|
111 |
+
return {"error": "Failed to fetch data from the NVD API.", "status_code": response.status_code}
|
112 |
+
except Exception as e:
|
113 |
+
return {"error": str(e)}
|
114 |
+
|
115 |
+
@tool("Fetch last 5 CVEs")
|
116 |
+
def fetchcve(keyword: str):
|
117 |
+
"""
|
118 |
+
Fetches the last 5 CVEs based on a keyword or phrase and returns the results in JSON format.
|
119 |
+
Use this exclusively to get the latest CVEs, or for todays CVEs.
|
120 |
+
Parameters:
|
121 |
+
- keyword (str): A word or phrase to search for in the CVE descriptions.
|
122 |
+
- date (str): An optional date to include in the search query.
|
123 |
+
Returns:
|
124 |
+
- JSON: A list of CVEs matching the keyword search.
|
125 |
+
"""
|
126 |
+
# Encode the spaces in the keyword(s) as "%20" for the URL
|
127 |
+
keyword_encoded = keyword.replace(" ", "%20")
|
128 |
+
# keyword_encoded = keyword_encoded.join(" 2023")
|
129 |
+
|
130 |
+
# Construct the URL for the API request
|
131 |
+
url = f"https://services.nvd.nist.gov/rest/json/cves/2.0/?pubStartDate={get_yesterday_formatted_date()}&pubEndDate={get_current_formatted_date()}&keywordSearch={keyword_encoded}&resultsPerPage=3"
|
132 |
+
print(url)
|
133 |
+
try:
|
134 |
+
# Send the request to the NVD API
|
135 |
+
response = requests.get(url)
|
136 |
+
# Check if the request was successful
|
137 |
+
if response.status_code == 200:
|
138 |
+
# Return the JSON response
|
139 |
+
return format_cve(response.json())
|
140 |
+
else:
|
141 |
+
return {"error": "Failed to fetch data from the NVD API.", "status_code": response.status_code}
|
142 |
+
except Exception as e:
|
143 |
+
return {"error": str(e)}
|
tools/cve_search_tool.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from langchain.tools import tool
|
4 |
+
import json
|
5 |
+
|
6 |
+
class CVESearchTool():
|
7 |
+
@tool("CVE search Tool")
|
8 |
+
def cvesearch(keyword: str):
|
9 |
+
"CVE (Common Vulnerabilities and Exposures) search tool is a useful tool to search for known security vulnerabilities and exposures in various software products, systems, and devices. It helps users to identify specific vulnerabilities by searching through the CVE database, which contains detailed information about vulnerabilities."
|
10 |
+
url = f"https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword={keyword}"
|
11 |
+
|
12 |
+
# Fetch the HTML content
|
13 |
+
response = requests.get(url)
|
14 |
+
if response.status_code == 200:
|
15 |
+
html_content = response.content
|
16 |
+
|
17 |
+
# Parse HTML
|
18 |
+
soup = BeautifulSoup(html_content, 'html.parser')
|
19 |
+
|
20 |
+
# Find CVE records
|
21 |
+
cves = soup.find_all('td', {'valign': 'top', 'nowrap': 'nowrap'})
|
22 |
+
|
23 |
+
# Create a dictionary to store CVEs and descriptions
|
24 |
+
cve_dict = {}
|
25 |
+
|
26 |
+
# Iterate through CVE records
|
27 |
+
for cve in cves:
|
28 |
+
cve_id = cve.text.strip() # Extract CVE ID
|
29 |
+
description = cve.find_next('td').text.strip() # Extract Description
|
30 |
+
cve_dict[cve_id] = description
|
31 |
+
|
32 |
+
# Convert dictionary to JSON string
|
33 |
+
json_string = json.dumps(cve_dict, indent=4)
|
34 |
+
# return json_string
|
35 |
+
return json_string
|
36 |
+
else:
|
37 |
+
print("Failed to fetch the page:", response.status_code)
|
38 |
+
return None
|
tools/elastic_tool.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from elasticsearch import Elasticsearch
|
4 |
+
from langchain.tools import tool
|
5 |
+
|
6 |
+
es = Elasticsearch(
|
7 |
+
"https://localhost:9200",
|
8 |
+
basic_auth=("elastic","dVJI85*y60R3ZVbECj1w"),
|
9 |
+
ca_certs="/Volumes/macOS/Projects/PFE UM6P/elasticsearch-8.12.1/config/certs/http_ca.crt"
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
class EventSearchTool():
|
14 |
+
@tool("Event search Tool")
|
15 |
+
def search(keyword: str):
|
16 |
+
"""Useful tool to search for an indicator of compromise or an security event
|
17 |
+
Parameters:
|
18 |
+
- keyword: The keyword to search for
|
19 |
+
Returns:
|
20 |
+
- A list of events that match the keyword
|
21 |
+
"""
|
22 |
+
|
23 |
+
|
24 |
+
# if not es.ping():
|
25 |
+
# raise "ElasticNotReachable"
|
26 |
+
|
27 |
+
query = {
|
28 |
+
"match": {"value": {
|
29 |
+
"query": keyword
|
30 |
+
}}
|
31 |
+
}
|
32 |
+
|
33 |
+
# Execute the search query
|
34 |
+
res = es.search(size=5, index="all_events_full", query=query, knn=None, _source=["event_id", "event_title", "event_date", "category", "attribute_tags", "type", "value"])
|
35 |
+
hits = res["hits"]["hits"]
|
36 |
+
events = [x['_source'] for x in hits]
|
37 |
+
|
38 |
+
return events
|
39 |
+
|
40 |
+
|
41 |
+
@tool("Event search by event_id Tool")
|
42 |
+
def get_event_by_id(id:str):
|
43 |
+
"""Useful tool to search for an event by its id, and return the full event details
|
44 |
+
Parameters:
|
45 |
+
- id: The event id to search for
|
46 |
+
Returns:
|
47 |
+
- The full details of the event with the specified id
|
48 |
+
"""
|
49 |
+
|
50 |
+
if not es.ping():
|
51 |
+
raise "ElasticNotReachable"
|
52 |
+
res = es.search(index="all_events_full", query={"match": {"event_id": id}}, _source=["event_id", "event_title", "event_date", "category", "attribute_tags", "type", "value"])
|
53 |
+
hits = res["hits"]["hits"]
|
54 |
+
events = [x['_source'] for x in hits]
|
55 |
+
|
56 |
+
return events
|
57 |
+
|
tools/log_tool.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
## TO_IMPLEMENT_LATER
|
tools/misp_tool.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.tools import tool
|
2 |
+
# import pymisp
|
3 |
+
|
4 |
+
|
5 |
+
from pymisp import PyMISP
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
import os
|
8 |
+
|
9 |
+
load_dotenv(override=True)
|
10 |
+
|
11 |
+
URL = os.getenv('MISP_URL')
|
12 |
+
KEY = os.getenv('MISP_KEY')
|
13 |
+
verify_cert = False
|
14 |
+
|
15 |
+
print(URL, KEY)
|
16 |
+
|
17 |
+
misp = PyMISP(url=URL, key=KEY, ssl=verify_cert)
|
18 |
+
|
19 |
+
class MispTool():
|
20 |
+
@tool("MISP search Tool by keyword")
|
21 |
+
def search(keyword: str):
|
22 |
+
"""Useful tool to search for an indicator of compromise or an security event by keyword
|
23 |
+
Parameters:
|
24 |
+
- keyword: The keyword to search for
|
25 |
+
Returns:
|
26 |
+
- A list of events that match the keyword
|
27 |
+
"""
|
28 |
+
|
29 |
+
events = misp.search(controller='attributes', value=keyword, limit=5, metadata=True, include_event_tags=False, include_context=False, return_format='json', sg_reference_only=True)
|
30 |
+
|
31 |
+
if len(events['Attribute']) == 0:
|
32 |
+
return "No events found matching the search criteria."
|
33 |
+
|
34 |
+
results = """Answer user question using these search results:\n\n"""
|
35 |
+
return results + str(events)
|
36 |
+
|
37 |
+
@tool("MISP search Tool by date")
|
38 |
+
def search_by_date(date_from: str = None, date_to: str = None):
|
39 |
+
"""Useful tool to retrieve events that match a specific date or date range, use this if you know the date of the event
|
40 |
+
Parameters:
|
41 |
+
- date_from: The start date of the event
|
42 |
+
- date_to: The end date of the event
|
43 |
+
Not necessary to provide both dates, you can provide one or the other
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
- A list of events that match the date or date range
|
47 |
+
"""
|
48 |
+
|
49 |
+
events = misp.search(controller='attributes',date_from=date_from, date_to=date_to, limit=5)
|
50 |
+
return events
|
51 |
+
|
52 |
+
@tool("MISP search Tool by event_id")
|
53 |
+
def search_by_event_id(event_id: str | int):
|
54 |
+
"""Useful tool to retrieve events by their ID, use this if you know the ID of the event.
|
55 |
+
Parameters:
|
56 |
+
- event_id: The ID of the event
|
57 |
+
Returns:
|
58 |
+
- A list of events that match the event ID
|
59 |
+
"""
|
60 |
+
|
61 |
+
events = misp.search(controller='attributes', eventid=event_id, limit=1)
|
62 |
+
return events
|
63 |
+
|
tools/mitre_tool.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from stix2 import MemoryStore, Filter
|
3 |
+
from taxii2client.v20 import Server # only specify v20 if your installed version is >= 2.0.0
|
4 |
+
|
5 |
+
from langchain.tools import tool
|
6 |
+
|
7 |
+
|
8 |
+
def get_data_from_branch(domain, branch="master"):
|
9 |
+
"""get the ATT&CK STIX data from MITRE/CTI. Domain should be 'enterprise-attack', 'mobile-attack' or 'ics-attack'. Branch should typically be master."""
|
10 |
+
BASE_URL = f"https://raw.githubusercontent.com/mitre/cti/{branch}/{domain}/{domain}.json"
|
11 |
+
stix_json = requests.get(BASE_URL).json()
|
12 |
+
return MemoryStore(stix_data=stix_json["objects"])
|
13 |
+
|
14 |
+
store = {
|
15 |
+
"enterprise": get_data_from_branch("enterprise-attack"),
|
16 |
+
"mobile": get_data_from_branch("mobile-attack"),
|
17 |
+
"ics": get_data_from_branch("ics-attack")
|
18 |
+
}
|
19 |
+
|
20 |
+
class MitreTool():
|
21 |
+
|
22 |
+
@tool("MITRE Technique search by ID")
|
23 |
+
def get_technique_by_id(domain: str, technique_id: str):
|
24 |
+
"""Get the technique by its ID. Domain should be 'enterprise', 'mobile' or 'ics'
|
25 |
+
Techniques represent 'how' an adversary achieves a tactical goal by performing an action. For example, an adversary may dump credentials to achieve credential access.
|
26 |
+
"""
|
27 |
+
result = store[domain].query([Filter('external_references.external_id', '=', technique_id)])
|
28 |
+
return result if result else "No technique found with that ID"
|
29 |
+
|
30 |
+
@tool("MITRE Technique search by name")
|
31 |
+
def get_technique_by_name(domain: str, technique_name: str):
|
32 |
+
"""Get the technique by its name. Domain should be 'enterprise', 'mobile' or 'ics'
|
33 |
+
Techniques represent 'how' an adversary achieves a tactical goal by performing an action. For example, an adversary may dump credentials to achieve credential access."""
|
34 |
+
result = store[domain].query([Filter('name', 'contains', technique_name), Filter('type', '=', 'attack-pattern')])
|
35 |
+
return result if result else "No technique found with that name"
|
36 |
+
|
37 |
+
@tool("MITRE Malware search by name")
|
38 |
+
def get_malware_by_name(domain: str, malware_name: str):
|
39 |
+
"""Get the malware by its name. Domain should be 'enterprise', 'mobile' or 'ics'
|
40 |
+
Malware represents software used to achieve a tactical goal by performing an action. For example, an adversary may use malware to achieve initial access."""
|
41 |
+
result = store[domain].query([Filter('name', 'contains', malware_name), Filter('type', '=', 'malware')])
|
42 |
+
return result if result else "No malware found with that name"
|
43 |
+
|
44 |
+
@tool("MITRE Technique search by keyword")
|
45 |
+
def get_tactic_by_keyword(domain: str, keyword: str):
|
46 |
+
"""Search for tactics/techniques by a keyword. Domain should be 'enterprise', 'mobile' or 'ics'
|
47 |
+
Tactics represent the "why" of an ATT&CK technique or sub-technique. It is the adversary's tactical goal: the reason for performing an action. For example, an adversary may want to achieve credential access."""
|
48 |
+
result = store[domain].query([Filter('description', 'contains', keyword)], Filter('type', '=', 'attack-pattern'))
|
49 |
+
return result[0] if result else "No tactics/techniques matches the keyword you provided"
|
tools/scraper_tools.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from langchain.tools import tool
|
4 |
+
|
5 |
+
class ScraperTool():
|
6 |
+
@tool("Scraper Tool")
|
7 |
+
def scrape(url: str):
|
8 |
+
"Useful tool to scrap a website content, use to learn more about a given url."
|
9 |
+
|
10 |
+
headers = {
|
11 |
+
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
|
12 |
+
|
13 |
+
response = requests.get(url, headers=headers)
|
14 |
+
|
15 |
+
# Check if the request was successful
|
16 |
+
if response.status_code == 200:
|
17 |
+
# Parse the HTML content of the page
|
18 |
+
soup = BeautifulSoup(response.text, 'html.parser')
|
19 |
+
|
20 |
+
article = soup.find(id='insertArticle')
|
21 |
+
|
22 |
+
if article:
|
23 |
+
# Extract and print the text from the article
|
24 |
+
text = (article.get_text(separator=' ', strip=True))
|
25 |
+
else:
|
26 |
+
print("Article with specified ID not found.")
|
27 |
+
|
28 |
+
return text
|
29 |
+
else:
|
30 |
+
print("Failed to retrieve the webpage")
|
31 |
+
|