GeorgiosIoannouCoder's picture
Upload 3 files
8a9d49c verified
raw
history blame
14 kB
#############################################################################################################################
# Filename : app.py
# Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
# Author : Georgios Ioannou
#
# TODO: Add code for Google Gemma 7b and 7b-it.
# TODO: Write code documentation.
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################
# Import libraries.
import os # Load environment variable(s).
import requests # Send HTTP GET request to Hugging Face models for inference.
import streamlit as st # Build the GUI of the application.
import streamlit.components.v1 as components
from dataclasses import dataclass
from dotenv import find_dotenv, load_dotenv # Read local .env file.
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from policies import complex_policy, simple_policy
from transformers import pipeline # Access to Hugging Face models.
from typing import Literal
#############################################################################################################################
# Populate the process environment from a local .env file, if one is found.
load_dotenv(dotenv_path=find_dotenv())
# API credentials read from the environment; each is None when unset.
HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
#############################################################################################################################
@dataclass()
class Message:
    """One entry of the chat transcript shown on the page.

    Attributes:
        origin: Author of the entry, either ``"human"`` (the user) or ``"ai"``.
        message: Text displayed in the chat bubble.
    """

    origin: Literal["human", "ai"]  # Who produced this entry.
    message: str  # Text of the entry.
#############################################################################################################################
# Remember (keep) per-session values, such as the chat history, across reruns of the application.
def initialize_session_state(state=None):
    """Seed the session state with default values the first time it is seen.

    Keys that already exist are left untouched, so values stored earlier in the
    session (chat history, token count, user choices) survive script reruns.

    Args:
        state: Object supporting ``key in state`` and attribute assignment.
            Defaults to ``st.session_state``; exposed so the function can be
            exercised without a running Streamlit session.
    """
    if state is None:
        state = st.session_state
    # Built on every call so each new session gets its own (mutable) history list.
    defaults = {
        "history": [],
        "token_count": 0,
        "policy": "No Policy",
        "is_gemma_model": False,
        # Must be one of the AI icon file names main() selects ("ai1.png"/"ai2.png").
        "selected_ai_icon": "ai1.png",
        "selected_user_icon": "man.png",
    }
    for key, value in defaults.items():
        if key not in state:
            setattr(state, key, value)
#############################################################################################################################
# Name of the party whose messages the safety-check prompt asks about.
_SAFETY_ROLE = "User"  # This can be an Agent.

# Instruction block placed before the policy categories in a safety-check prompt.
_SAFETY_TASK = f"""
Task: Check if there is unsafe content in
'{_SAFETY_ROLE}' messages in conversations
according our safety policy with the below categories.
"""

# Instruction block telling the check model how to format its verdict.
_SAFETY_OUTPUT_FORMAT = f"""
Provide your safety assessment for '{_SAFETY_ROLE}' in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories.
"""

# Shown instead of the model's answer when that answer fails safety check 2.
_UNSAFE_OUTPUT_NOTICE = "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!"

# Shown when a Gemma model is selected; Gemma inference is not implemented yet.
_GEMMA_UNSUPPORTED_NOTICE = "Gemma models are not supported yet. Please select a GPT model."


def _build_safety_prompt(text):
    """Return the safety-check prompt for *text* under the selected policy.

    With "No Policy" the text is returned unchanged. Otherwise it is wrapped
    in the task description, the selected policy's categories, a one-turn
    conversation containing *text*, and the expected output format.

    Raises:
        ValueError: If ``st.session_state.policy`` is not a known policy name.
    """
    policy_name = st.session_state.policy
    if policy_name == "No Policy":
        return text
    if policy_name == "Simple Policy":
        policy_text = simple_policy
    elif policy_name == "Complex Policy":
        policy_text = complex_policy
    else:
        raise ValueError(f"Unknown policy: {policy_name!r}")
    conversation = f"""
<BEGIN CONVERSATION>
User: {text}
<END CONVERSATION>
"""
    return f"""
{_SAFETY_TASK}
{policy_text}
{conversation}
{_SAFETY_OUTPUT_FORMAT}
"""


def _is_flagged_unsafe(assessment):
    """Return True if the check model's reply contains "unsafe" (case-insensitive)."""
    return "unsafe" in assessment.lower()


def on_click_callback():
    """Handle a chat submission from the form.

    Pipeline (GPT models only):
      1. Safety check 1 (what goes in): run the user's prompt, wrapped for the
         selected policy, through ``st.session_state.conversation``. If the
         reply contains "unsafe", show that reply and stop.
      2. Answer the prompt with a fresh ``ConversationChain`` on the selected
         model.
      3. Safety check 2 (what goes out): run the answer through the same check.
         Show the answer if it passes, otherwise show a warning notice.

    Reads ``human_prompt``, ``policy``, ``model`` and ``conversation`` from
    ``st.session_state``; appends ``Message`` entries to
    ``st.session_state.history`` and adds the tokens used to
    ``st.session_state.token_count``.
    """
    human_prompt = st.session_state.human_prompt
    # Ignore empty submissions (e.g. pressing Enter on an empty input box).
    if not (human_prompt or "").strip():
        return

    # Check the selected model name directly so the decision cannot be
    # affected by a stale is_gemma_model flag. Without this guard the Gemma
    # path would leave the check result unbound and crash.
    if "gemma" in st.session_state.get("model", ""):
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("ai", _GEMMA_UNSUPPORTED_NOTICE))
        return

    with get_openai_callback() as cb:
        try:
            # Safety check 1: check what goes in (user input).
            verdict_in = st.session_state.conversation.run(
                _build_safety_prompt(human_prompt)
            )
            st.session_state.history.append(Message("human", human_prompt))
            if _is_flagged_unsafe(verdict_in):
                # The prompt is unsafe: show the check result instead of answering.
                st.session_state.history.append(Message("ai", verdict_in))
                return

            # The prompt passed: answer it with the selected model.
            answer_chain = ConversationChain(
                llm=OpenAI(
                    temperature=0.2,
                    openai_api_key=OPENAI_API_KEY,
                    model_name=st.session_state.model,
                ),
            )
            llm_response = answer_chain.run(human_prompt)

            # Safety check 2: check what goes out (LLM output).
            verdict_out = st.session_state.conversation.run(
                _build_safety_prompt(llm_response)
            )
            if _is_flagged_unsafe(verdict_out):
                st.session_state.history.append(Message("ai", _UNSAFE_OUTPUT_NOTICE))
            else:
                st.session_state.history.append(Message("ai", llm_response))
        finally:
            # cb.total_tokens accumulates over every call made inside this
            # `with` block, so it is added exactly once. Adding it after each
            # call would count the earlier calls several times.
            st.session_state.token_count += cb.total_tokens
#############################################################################################################################
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet at *file_name* into the page inside a ``<style>`` tag.

    Args:
        file_name: Path to a CSS file, decoded as UTF-8.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can fail
    # on, or garble, non-ASCII characters in a UTF-8 stylesheet.
    with open(file_name, encoding="utf-8") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)
#############################################################################################################################
# Options offered in the sidebar selectors.
_MODELS = [
    "gpt-4-turbo",
    "gpt-4",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-instruct",
    "gemma-7b",
    "gemma-7b-it",
]
_POLICIES = ["No Policy", "Complex Policy", "Simple Policy"]
# Sidebar label -> icon file name, referenced in the chat HTML as app/static/<name>.
_AI_ICONS = {"AI 1": "ai1.png", "AI 2": "ai2.png"}
_USER_ICONS = {"Man": "man.png", "Woman": "woman.png"}

# Script injected into the parent page: pressing Enter clicks the button whose
# text is "Submit".
_ENTER_TO_SUBMIT_JS = """
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
switch (e.key) {
case 'Enter':
submitButton.click();
break;
}
});
</script>
"""


def _render_header():
    """Render the title, the two subtitles, and the centered logo."""
    # Title.
    st.markdown(
        """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
Responsible AI</h1>""",
        unsafe_allow_html=True,
    )
    # Subtitle 1.
    st.markdown(
        """<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
Showcase the importance of Responsible AI in LLMs</h3>""",
        unsafe_allow_html=True,
    )
    # Subtitle 2.
    st.markdown(
        """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
CUNY Tech Prep Tutorial 6</h2>""",
        unsafe_allow_html=True,
    )
    # Image in the middle of three columns, to center it.
    _, center_col, _ = st.columns(3)
    with center_col:
        st.image(image="./static/ctp.png")


def _render_sidebar():
    """Render the sidebar selectors and store the choices in session state."""
    # Model.
    selected_model = st.sidebar.selectbox("Select Model:", _MODELS)
    st.sidebar.write(f"Current Model: {selected_model}")
    st.session_state.model = selected_model
    # Recomputed on every rerun so the flag returns to False after switching
    # from a Gemma model back to a GPT model.
    st.session_state.is_gemma_model = "gemma" in selected_model
    if "gpt" in selected_model:
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=selected_model,
            ),
        )
    # For Gemma models, loading the model from Hugging Face is not implemented yet.

    # Policy.
    selected_policy = st.sidebar.selectbox("Select Policy:", _POLICIES)
    st.sidebar.write(f"Current Policy: {selected_policy}")
    st.session_state.policy = selected_policy

    # AI icon.
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", list(_AI_ICONS))
    st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")
    st.session_state.selected_ai_icon = _AI_ICONS[selected_ai_icon]

    # User icon.
    selected_user_icon = st.sidebar.selectbox("User Icon:", list(_USER_ICONS))
    st.sidebar.write(f"Current User Icon: {selected_user_icon}")
    st.session_state.selected_user_icon = _USER_ICONS[selected_user_icon]


def _chat_message_html(chat):
    """Return the HTML for one chat bubble (icon plus message text).

    The message text comes from the user or the LLM and is rendered with
    ``unsafe_allow_html=True``, so it is HTML-escaped to stop it from
    injecting markup or scripts. Newlines become ``<br>`` so line breaks are
    kept inside the bubble.
    """
    from html import escape

    is_ai = chat.origin == "ai"
    row_class = "chat-row" if is_ai else "chat-row row-reverse"
    bubble_class = "ai-bubble" if is_ai else "human-bubble"
    icon = (
        st.session_state.selected_ai_icon
        if is_ai
        else st.session_state.selected_user_icon
    )
    text = escape(chat.message).replace("\n", "<br>")
    return (
        f'<div class="{row_class}">'
        f'<img class="chat-icon" src="app/static/{icon}" width=32 height=32>'
        f'<div class="chat-bubble {bubble_class}">&#8203;{text}</div>'
        "</div>"
    )


def _render_chat_form(form):
    """Fill *form* with the prompt text input and the Submit button."""
    with form:
        st.markdown("**Chat**")
        input_col, button_col = st.columns((6, 1))
        # Large text input in the left column.
        input_col.text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )
        # Primary-styled button in the right column.
        button_col.form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )


def _render_footer():
    """Render the centered link to the author's GitHub profile."""
    st.markdown(
        """
<p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
<a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
</p>
""",
        unsafe_allow_html=True,
    )


# Main function to create the Streamlit web application.
def main():
    """Build and render the Responsible AI page.

    Order: session defaults, page config and CSS, header, sidebar selectors,
    chat history, chat form, token counter, footer, and the Enter-to-submit
    script.
    """
    initialize_session_state()
    # Page title and favicon.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
    # Load CSS.
    local_css("./static/styles/styles.css")

    _render_header()
    _render_sidebar()

    # Placeholders, created in page order: chat messages, user input, token usage.
    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            st.markdown(_chat_message_html(chat), unsafe_allow_html=True)
        # Vertical spacing below the last message.
        for _ in range(3):
            st.markdown("")

    _render_chat_form(prompt_placeholder)

    token_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
"""
    )

    _render_footer()

    # Use the Enter key on the keyboard to click the Submit button.
    components.html(_ENTER_TO_SUBMIT_JS, height=0, width=0)
#############################################################################################################################
# Run the application when this file is executed as the main script.
if __name__ == "__main__":
    main()