|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import pickle |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Serialized classifier used for every prediction. Judging by the file name it
# is presumably a random-forest model -- confirm against the training code.
# NOTE(security): unpickling executes arbitrary code embedded in the file, so
# this path must only ever point at a trusted artifact.
_MODEL_PATH = 'random_forest_classifier (1).pkl'

with open(_MODEL_PATH, 'rb') as model_file:
    model = pickle.load(model_file)

# Human-readable class names. main() indexes this list with the model's
# predicted value and uses it to label the probability chart, so the order is
# assumed to match the model's class encoding -- TODO confirm.
class_labels = [
    'phishing',
    'benign',
    'defacement',
    'malware',
]
|
|
|
# One row per model feature, in the exact column order the classifier is fed:
# (name shown on the feature-importance chart, widget label, is 0/1 flag).
# Keeping the three views in one table prevents them from drifting apart.
_FEATURE_INPUTS = (
    ('use_of_ip_address', "Use of IP Address (0 or 1)", True),
    ('abnormal_url', "Abnormal URL (0 or 1)", True),
    ('google_index', "Google Index (0 or 1)", True),
    ('count_www', "Count WWW", False),
    ('count@', "Count @", False),
    ('count_dir', "Count Directory", False),
    ('count_embed_domian', "Count Embedded Domain", False),
    ('short_url', "Short URL (0 or 1)", True),
    ('count_https', "Count HTTPS", False),
    ('count_http', "Count HTTP", False),
    ('count%', "Count %", False),
    ('count?', "Count ?", False),
    ('count-', "Count -", False),
    ('count=', "Count =", False),
    ('url_length', "URL Length", False),
    ('hostname_length', "Hostname Length", False),
    ('sus_url', "Suspicious URL (0 or 1)", True),
    ('fd_length', "FD Length", False),
    ('tld_length', "TLD Length", False),
    ('count-digits', "Count Digits", False),
    ('count-letters', "Count Letters", False),
)


def _read_feature_values():
    """Render one number input per feature and return the values in order.

    Flag features are restricted to 0 or 1; every other feature only has a
    lower bound of 0.
    """
    values = []
    for _name, label, is_flag in _FEATURE_INPUTS:
        if is_flag:
            value = st.number_input(label, min_value=0, max_value=1, step=1)
        else:
            value = st.number_input(label, min_value=0)
        values.append(value)
    return values


def _show_figure(fig):
    """Render a matplotlib figure in the app, then release it."""
    try:
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it created registered until it is closed.
        # Streamlit re-executes the script on each interaction, so unclosed
        # figures would accumulate for the lifetime of the server process.
        plt.close(fig)


def _show_prediction(features):
    """Run the model on ``features``, display the verdict, return its label."""
    prediction = model.predict(features)
    probabilities = model.predict_proba(features)[0]

    # The predicted value is used directly as an index into class_labels, so
    # the model is expected to emit integer class codes here.
    predicted_label = class_labels[prediction[0]]
    headline = f"The URL is predicted to be: {predicted_label}"
    st.subheader(headline)

    # Pick the alert style by class; anything not listed is shown as success.
    alert_by_label = {
        'phishing': st.error,
        'malware': st.warning,
        'defacement': st.warning,
    }
    alert_by_label.get(predicted_label, st.success)(headline)

    st.subheader("Prediction Probabilities:")
    prob_df = pd.DataFrame(probabilities, index=class_labels, columns=["Probability"])
    st.bar_chart(prob_df)
    return predicted_label


def _show_feature_importance():
    """Plot the model's ``feature_importances_`` as a horizontal bar chart."""
    st.subheader("Feature Importance:")
    feature_names = [name for name, _label, _is_flag in _FEATURE_INPUTS]

    fig, ax = plt.subplots()
    ax.barh(feature_names, model.feature_importances_, color='skyblue')
    ax.set_xlabel('Importance')
    ax.set_title('Feature Importance')
    _show_figure(fig)


def _show_simulated_distribution(predicted_label):
    """Plot random per-class counts with the predicted class boosted by 100.

    The numbers are synthetic (drawn from ``np.random.randint``); they do not
    come from the model or from any recorded predictions.
    """
    st.subheader("Prediction Distribution (Simulated):")
    counts = np.random.randint(1, 100, size=4)
    counts[class_labels.index(predicted_label)] += 100

    fig, ax = plt.subplots()
    ax.bar(class_labels, counts, color=['red', 'green', 'blue', 'orange'])
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.set_title('Prediction Distribution')
    _show_figure(fig)


def main():
    """Build the page: feature inputs in the first column, results in the second.

    Results (verdict, class probabilities, feature importances and a simulated
    class distribution) are only rendered on the run in which the "Predict"
    button was clicked.
    """
    st.set_page_config(page_title="Malicious URL Detection", layout="wide")
    st.title("Malicious URL Detection")

    st.markdown("""
        <style>
        .main {background-color: #f5f5f5;}
        </style>
    """, unsafe_allow_html=True)

    st.markdown("### Enter the URL features to predict its category")

    input_col, output_col = st.columns([2, 2])

    with input_col:
        feature_values = _read_feature_values()

    with output_col:
        if not st.button("Predict"):
            return

        # The model is given a single row holding all features.
        features = np.array([feature_values])

        predicted_label = _show_prediction(features)
        _show_feature_importance()
        _show_simulated_distribution(predicted_label)
|
|
|
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|