"""Streamlit app that predicts a URL's category from manually entered features.

The user fills in 21 numeric URL features, presses "Predict", and the app
shows the predicted class, the per-class probabilities, the model's feature
importances, and a simulated class-distribution chart.
"""

import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st

# Load model
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only deploy this app with a model file from a trusted source.
# NOTE(review): Streamlit re-executes the script on every interaction, so the
# model is re-read from disk each time — consider caching if load time matters.
with open('random_forest_classifier (1).pkl', 'rb') as file:
    model = pickle.load(file)

# Define the possible outcomes
# NOTE(review): the code below indexes this list with the model's raw
# prediction and pairs it positionally with predict_proba columns, so it
# assumes the model was trained on integer labels 0..3 in this order — confirm.
class_labels = ['phishing', 'benign', 'defacement', 'malware']

# Single source of truth for the model inputs, in the order they are passed to
# the model: (feature name shown on the importance chart, widget label,
# whether the input is a 0/1 flag).
_FEATURE_SPECS = (
    ('use_of_ip_address', "Use of IP Address (0 or 1)", True),
    ('abnormal_url', "Abnormal URL (0 or 1)", True),
    ('google_index', "Google Index (0 or 1)", True),
    ('count_www', "Count WWW", False),
    ('count@', "Count @", False),
    ('count_dir', "Count Directory", False),
    ('count_embed_domian', "Count Embedded Domain", False),
    ('short_url', "Short URL (0 or 1)", True),
    ('count_https', "Count HTTPS", False),
    ('count_http', "Count HTTP", False),
    ('count%', "Count %", False),
    ('count?', "Count ?", False),
    ('count-', "Count -", False),
    ('count=', "Count =", False),
    ('url_length', "URL Length", False),
    ('hostname_length', "Hostname Length", False),
    ('sus_url', "Suspicious URL (0 or 1)", True),
    ('fd_length', "FD Length", False),
    ('tld_length', "TLD Length", False),
    ('count-digits', "Count Digits", False),
    ('count-letters', "Count Letters", False),
)


def _read_feature_inputs() -> list:
    """Render one number input per feature and return the values in model order."""
    values = []
    for _name, label, is_flag in _FEATURE_SPECS:
        if is_flag:
            value = st.number_input(label, min_value=0, max_value=1, step=1)
        else:
            value = st.number_input(label, min_value=0)
        values.append(value)
    return values


def _show_figure(fig) -> None:
    """Render a matplotlib figure, then close it.

    Figures created through pyplot stay registered until closed, so without
    the close every rerun of the script would accumulate another figure.
    """
    st.pyplot(fig)
    plt.close(fig)


def _show_prediction_banner(predicted_label: str) -> None:
    """Show the predicted label as a subheader and a colour-coded status box."""
    message = f"The URL is predicted to be: {predicted_label}"
    st.subheader(message)
    if predicted_label == 'phishing':
        st.error(message)
    elif predicted_label in ('malware', 'defacement'):
        st.warning(message)
    else:
        st.success(message)


def _show_feature_importance() -> None:
    """Plot the loaded model's feature importances as a horizontal bar chart."""
    st.subheader("Feature Importance:")
    feature_names = [name for name, _label, _is_flag in _FEATURE_SPECS]
    fig, ax = plt.subplots()
    ax.barh(feature_names, model.feature_importances_, color='skyblue')
    ax.set_xlabel('Importance')
    ax.set_title('Feature Importance')
    _show_figure(fig)


def _show_simulated_distribution(predicted_label: str) -> None:
    """Plot random per-class counts with the predicted class boosted by 100.

    The counts are synthetic (drawn fresh on each call); they do not reflect
    any real prediction history.
    """
    st.subheader("Prediction Distribution (Simulated):")
    prediction_counts = np.random.randint(1, 100, size=len(class_labels))
    # Simulate the current prediction
    prediction_counts[class_labels.index(predicted_label)] += 100
    fig, ax = plt.subplots()
    ax.bar(class_labels, prediction_counts,
           color=['red', 'green', 'blue', 'orange'])
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.set_title('Prediction Distribution')
    _show_figure(fig)


def main() -> None:
    """Build the page: feature inputs on the left, prediction output on the right."""
    st.set_page_config(page_title="Malicious URL Detection", layout="wide")

    st.title("Malicious URL Detection")
    st.markdown(""" """, unsafe_allow_html=True)

    st.markdown("### Enter the URL features to predict its category")

    # Layout: two columns for input and output
    input_col, output_col = st.columns([2, 2])

    with input_col:
        feature_values = _read_feature_inputs()

    with output_col:
        if st.button("Predict"):
            # One sample (row) with the features in _FEATURE_SPECS order.
            features = np.array([feature_values])
            prediction = model.predict(features)
            prediction_probabilities = model.predict_proba(features)[0]

            predicted_label = class_labels[prediction[0]]
            _show_prediction_banner(predicted_label)

            # Display prediction probabilities
            st.subheader("Prediction Probabilities:")
            prob_df = pd.DataFrame(prediction_probabilities,
                                   index=class_labels,
                                   columns=["Probability"])
            st.bar_chart(prob_df)

            # Example chart: Feature Importance
            _show_feature_importance()

            # Another example chart: Prediction distribution (simulated)
            _show_simulated_distribution(predicted_label)


if __name__ == '__main__':
    main()