VigilCiph3r's picture
Upload 3 files
d2cff7e verified
raw
history blame
5.37 kB
import streamlit as st
import pandas as pd
import numpy as np
import pickle
import matplotlib.pyplot as plt
# Deserialize the trained classifier from the pickle file that sits next to
# this script (path is relative to the current working directory).
# NOTE(review): unpickling can execute arbitrary code embedded in the file;
# only ship/load a model artifact that comes from a trusted source.
with open('random_forest_classifier (1).pkl', 'rb') as model_file:
    model = pickle.load(model_file)

# Display names for the model's output classes. main() indexes this list with
# the value returned by model.predict() and uses it to label the probability
# chart.
# NOTE(review): this assumes the model emits integer class codes in exactly
# this order -- confirm against the encoder used at training time.
class_labels = [
    'phishing',
    'benign',
    'defacement',
    'malware',
]
# CSS injected into the page to set the background colour of the main area.
_PAGE_STYLE = """
<style>
.main {background-color: #f5f5f5;}
</style>
"""

# One entry per model input: (widget label, name shown on the importance
# chart, whether the input is a 0/1 flag). The order here is the single source
# of truth for the widget order, the feature-vector column order and the
# feature-importance bar labels.
_FEATURE_SPECS = [
    ("Use of IP Address (0 or 1)", 'use_of_ip_address', True),
    ("Abnormal URL (0 or 1)", 'abnormal_url', True),
    ("Google Index (0 or 1)", 'google_index', True),
    ("Count WWW", 'count_www', False),
    ("Count @", 'count@', False),
    ("Count Directory", 'count_dir', False),
    ("Count Embedded Domain", 'count_embed_domian', False),
    ("Short URL (0 or 1)", 'short_url', True),
    ("Count HTTPS", 'count_https', False),
    ("Count HTTP", 'count_http', False),
    ("Count %", 'count%', False),
    ("Count ?", 'count?', False),
    ("Count -", 'count-', False),
    ("Count =", 'count=', False),
    ("URL Length", 'url_length', False),
    ("Hostname Length", 'hostname_length', False),
    ("Suspicious URL (0 or 1)", 'sus_url', True),
    ("FD Length", 'fd_length', False),
    ("TLD Length", 'tld_length', False),
    ("Count Digits", 'count-digits', False),
    ("Count Letters", 'count-letters', False),
]


def _read_feature_inputs():
    """Render one numeric input per feature and return the values in spec order."""
    values = []
    for label, _chart_name, is_flag in _FEATURE_SPECS:
        if is_flag:
            # Flags are restricted to the two values 0 and 1.
            value = st.number_input(label, min_value=0, max_value=1, step=1)
        else:
            value = st.number_input(label, min_value=0)
        values.append(value)
    return values


def _render_feature_importance():
    """Draw a horizontal bar chart of the model's feature importances."""
    st.subheader("Feature Importance:")
    chart_names = [chart_name for _label, chart_name, _is_flag in _FEATURE_SPECS]
    fig, ax = plt.subplots()
    try:
        ax.barh(chart_names, model.feature_importances_, color='skyblue')
        ax.set_xlabel('Importance')
        ax.set_title('Feature Importance')
        st.pyplot(fig)
    finally:
        # st.pyplot does not clear a figure that is passed explicitly, and
        # pyplot keeps every open figure alive; close it so repeated
        # predictions do not accumulate figures in memory.
        plt.close(fig)


def _render_simulated_distribution(predicted_label):
    """Draw a bar chart of random per-class counts, boosted for the predicted class.

    The counts are not real statistics: they are drawn at random on every call
    and the predicted class gets an extra 100 so that it stands out.
    """
    st.subheader("Prediction Distribution (Simulated):")
    prediction_counts = np.random.randint(1, 100, size=len(class_labels))
    prediction_counts[class_labels.index(predicted_label)] += 100
    fig, ax = plt.subplots()
    try:
        ax.bar(class_labels, prediction_counts, color=['red', 'green', 'blue', 'orange'])
        ax.set_xlabel('Class')
        ax.set_ylabel('Count')
        ax.set_title('Prediction Distribution')
        st.pyplot(fig)
    finally:
        # Same reason as for the importance chart: release the figure.
        plt.close(fig)


def _render_prediction(features):
    """Run the model on one feature row and render the result and charts.

    Args:
        features: 2-D array holding a single row, columns ordered as in
            ``_FEATURE_SPECS``.
    """
    prediction = model.predict(features)
    prediction_probabilities = model.predict_proba(features)[0]
    # NOTE(review): indexing class_labels with the prediction assumes the model
    # returns integer class codes whose order matches class_labels; the same
    # assumption labels the probability chart below. Confirm against the
    # label encoding used at training time.
    predicted_label = class_labels[prediction[0]]

    message = f"The URL is predicted to be: {predicted_label}"
    st.subheader(message)
    # Severity of the status box depends on the predicted category.
    if predicted_label == 'phishing':
        st.error(message)
    elif predicted_label in ('malware', 'defacement'):
        st.warning(message)
    else:
        st.success(message)

    st.subheader("Prediction Probabilities:")
    prob_df = pd.DataFrame(prediction_probabilities, index=class_labels, columns=["Probability"])
    st.bar_chart(prob_df)

    _render_feature_importance()
    _render_simulated_distribution(predicted_label)


def main():
    """Build the Streamlit page: feature inputs on the left, prediction on the right.

    The inputs are rendered on every run; the model is only queried, and the
    result charts only drawn, when the "Predict" button is pressed.
    """
    st.set_page_config(page_title="Malicious URL Detection", layout="wide")
    st.title("Malicious URL Detection")
    st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
    st.markdown("### Enter the URL features to predict its category")

    # Layout: two equally wide columns for input and output.
    input_col, output_col = st.columns([2, 2])

    with input_col:
        feature_values = _read_feature_inputs()

    with output_col:
        if st.button("Predict"):
            # The model expects a 2-D array: one row, one column per feature.
            _render_prediction(np.array([feature_values]))
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()