# Last change by asofter (commit 05bf37a): upgrade of the version.
import logging
import time
from datetime import timedelta
from typing import Any, Dict, List, Tuple

import streamlit as st
from llm_guard.input_scanners.anonymize import default_entity_types
from llm_guard.input_scanners.code import SUPPORTED_LANGUAGES as SUPPORTED_CODE_LANGUAGES
from llm_guard.output_scanners import get_scanner_by_name
from llm_guard.output_scanners.bias import MatchType as BiasMatchType
from llm_guard.output_scanners.deanonymize import MatchingStrategy as DeanonymizeMatchingStrategy
from llm_guard.output_scanners.gibberish import MatchType as GibberishMatchType
from llm_guard.output_scanners.language import MatchType as LanguageMatchType
from llm_guard.output_scanners.toxicity import MatchType as ToxicityMatchType
from llm_guard.vault import Vault
from streamlit_tags import st_tags
# Named logger for the LLM Guard playground; get_scanner reports scanner
# initialization through it at DEBUG level.
logger = logging.getLogger("llm-guard-playground")
def _split_lines(text: str) -> List[str]:
    """Turn a one-entry-per-line text area value into a list of entries.

    Blank (empty or whitespace-only) lines are dropped. Otherwise an empty
    string ends up in the list, e.g. from the trailing newline of the
    BanSubstrings default value, and an empty substring or regex matches any
    text. ``str.splitlines`` is used so CRLF line endings do not leave stray
    carriage returns in the entries.
    """
    return [line for line in text.splitlines() if line.strip()]


def init_settings() -> Tuple[List[str], Dict[str, Dict]]:
    """Render the sidebar controls for output scanners and collect their settings.

    Returns:
        A tuple ``(enabled_scanners, settings)``. ``enabled_scanners`` is the
        list of scanner names chosen in the sidebar multiselect. ``settings``
        maps a scanner name to the keyword arguments gathered from its
        expander. Scanners without configurable options (e.g. ``LanguageSame``)
        get no entry; ``NoRefusalLight`` and ``URLReachability`` map to ``{}``.
    """
    all_scanners = [
        "BanCode",
        "BanCompetitors",
        "BanSubstrings",
        "BanTopics",
        "Bias",
        "Code",
        "Deanonymize",
        "JSON",
        "Language",
        "LanguageSame",
        "MaliciousURLs",
        "NoRefusal",
        # These two were previously fused into one bogus name by a missing comma.
        "NoRefusalLight",
        "ReadingTime",
        "FactualConsistency",
        "Gibberish",
        "Regex",
        "Relevance",
        "Sensitive",
        "Sentiment",
        "Toxicity",
        "URLReachability",
    ]

    st_enabled_scanners = st.sidebar.multiselect(
        "Select scanners",
        options=all_scanners,
        default=all_scanners,
        help="The list can be found here: https://llm-guard.com/output_scanners/bias/",
    )

    settings = {}

    if "BanCode" in st_enabled_scanners:
        st_bc_expander = st.sidebar.expander(
            "Ban Code",
            expanded=False,
        )

        with st_bc_expander:
            st_bc_threshold = st.slider(
                label="Threshold",
                value=0.95,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_code_threshold",
            )

        settings["BanCode"] = {"threshold": st_bc_threshold}

    if "BanCompetitors" in st_enabled_scanners:
        st_bcomp_expander = st.sidebar.expander(
            "Ban Competitors",
            expanded=False,
        )

        with st_bcomp_expander:
            st_bcomp_competitors = st_tags(
                label="List of competitors",
                text="Type and press enter",
                value=["openai", "anthropic", "deepmind", "google"],
                suggestions=[],
                maxtags=30,
                key="bc_competitors",
            )

            st_bcomp_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_competitors_threshold",
            )

        settings["BanCompetitors"] = {
            "competitors": st_bcomp_competitors,
            "threshold": st_bcomp_threshold,
        }

    if "BanSubstrings" in st_enabled_scanners:
        st_bs_expander = st.sidebar.expander(
            "Ban Substrings",
            expanded=False,
        )

        with st_bs_expander:
            # Blank lines are dropped: an empty banned substring would flag every output.
            st_bs_substrings = _split_lines(
                st.text_area(
                    "Enter substrings to ban (one per line)",
                    value="test\nhello\nworld\n",
                    height=200,
                )
            )

            st_bs_match_type = st.selectbox(
                "Match type", ["str", "word"], index=0, key="bs_match_type"
            )
            st_bs_case_sensitive = st.checkbox(
                "Case sensitive", value=False, key="bs_case_sensitive"
            )
            st_bs_redact = st.checkbox("Redact", value=False, key="bs_redact")
            st_bs_contains_all = st.checkbox("Contains all", value=False, key="bs_contains_all")

        settings["BanSubstrings"] = {
            "substrings": st_bs_substrings,
            "match_type": st_bs_match_type,
            "case_sensitive": st_bs_case_sensitive,
            "redact": st_bs_redact,
            "contains_all": st_bs_contains_all,
        }

    if "BanTopics" in st_enabled_scanners:
        st_bt_expander = st.sidebar.expander(
            "Ban Topics",
            expanded=False,
        )

        with st_bt_expander:
            st_bt_topics = st_tags(
                label="List of topics",
                text="Type and press enter",
                value=["violence"],
                suggestions=[],
                maxtags=30,
                key="bt_topics",
            )

            st_bt_threshold = st.slider(
                label="Threshold",
                value=0.6,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_topics_threshold",
            )

        settings["BanTopics"] = {"topics": st_bt_topics, "threshold": st_bt_threshold}

    if "Bias" in st_enabled_scanners:
        st_bias_expander = st.sidebar.expander(
            "Bias",
            expanded=False,
        )

        with st_bias_expander:
            st_bias_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="bias_threshold",
            )

            st_bias_match_type = st.selectbox(
                "Match type", [e.value for e in BiasMatchType], index=1, key="bias_match_type"
            )

        settings["Bias"] = {
            "threshold": st_bias_threshold,
            "match_type": BiasMatchType(st_bias_match_type),
        }

    if "Code" in st_enabled_scanners:
        st_cd_expander = st.sidebar.expander(
            "Code",
            expanded=False,
        )

        with st_cd_expander:
            st_cd_languages = st.multiselect(
                "Programming languages",
                options=SUPPORTED_CODE_LANGUAGES,
                default=["Python"],
            )

            st_cd_is_blocked = st.checkbox("Is blocked", value=False, key="cd_is_blocked")

        settings["Code"] = {
            "languages": st_cd_languages,
            "is_blocked": st_cd_is_blocked,
        }

    if "Deanonymize" in st_enabled_scanners:
        st_de_expander = st.sidebar.expander(
            "Deanonymize",
            expanded=False,
        )

        with st_de_expander:
            st_de_matching_strategy = st.selectbox(
                "Matching strategy", [e.value for e in DeanonymizeMatchingStrategy], index=0
            )

        settings["Deanonymize"] = {
            "matching_strategy": DeanonymizeMatchingStrategy(st_de_matching_strategy)
        }

    if "JSON" in st_enabled_scanners:
        st_json_expander = st.sidebar.expander(
            "JSON",
            expanded=False,
        )

        with st_json_expander:
            st_json_required_elements = st.slider(
                label="Required elements",
                value=0,
                min_value=0,
                max_value=10,
                step=1,
                key="json_required_elements",
                help="The minimum number of JSON elements that should be present",
            )

            st_json_repair = st.checkbox(
                "Repair", value=False, help="Attempt to repair the JSON", key="json_repair"
            )

        settings["JSON"] = {
            "required_elements": st_json_required_elements,
            "repair": st_json_repair,
        }

    if "Language" in st_enabled_scanners:
        st_lan_expander = st.sidebar.expander(
            "Language",
            expanded=False,
        )

        with st_lan_expander:
            st_lan_valid_language = st.multiselect(
                "Languages",
                [
                    "ar",
                    "bg",
                    "de",
                    "el",
                    "en",
                    "es",
                    "fr",
                    "hi",
                    "it",
                    "ja",
                    "nl",
                    "pl",
                    "pt",
                    "ru",
                    "sw",
                    "th",
                    "tr",
                    "ur",
                    "vi",
                    "zh",
                ],
                default=["en"],
            )

            st_lan_match_type = st.selectbox(
                "Match type", [e.value for e in LanguageMatchType], index=1, key="lan_match_type"
            )

        settings["Language"] = {
            "valid_languages": st_lan_valid_language,
            "match_type": LanguageMatchType(st_lan_match_type),
        }

    if "MaliciousURLs" in st_enabled_scanners:
        st_murls_expander = st.sidebar.expander(
            "Malicious URLs",
            expanded=False,
        )

        with st_murls_expander:
            st_murls_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="murls_threshold",
            )

        settings["MaliciousURLs"] = {"threshold": st_murls_threshold}

    if "NoRefusal" in st_enabled_scanners:
        st_no_ref_expander = st.sidebar.expander(
            "No refusal",
            expanded=False,
        )

        with st_no_ref_expander:
            st_no_ref_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="no_ref_threshold",
            )

        settings["NoRefusal"] = {"threshold": st_no_ref_threshold}

    if "NoRefusalLight" in st_enabled_scanners:
        # No configurable options.
        settings["NoRefusalLight"] = {}

    if "ReadingTime" in st_enabled_scanners:
        st_rt_expander = st.sidebar.expander(
            "Reading Time",
            expanded=False,
        )

        with st_rt_expander:
            st_rt_max_reading_time = st.slider(
                label="Max reading time (in minutes)",
                value=5,
                min_value=0,
                max_value=3600,
                step=5,
                key="rt_max_reading_time",
            )

            st_rt_truncate = st.checkbox(
                "Truncate",
                value=False,
                help="Truncate the text to the max reading time",
                key="rt_truncate",
            )

        settings["ReadingTime"] = {"max_time": st_rt_max_reading_time, "truncate": st_rt_truncate}

    if "FactualConsistency" in st_enabled_scanners:
        st_fc_expander = st.sidebar.expander(
            "FactualConsistency",
            expanded=False,
        )

        with st_fc_expander:
            st_fc_minimum_score = st.slider(
                label="Minimum score",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="fc_threshold",
            )

        settings["FactualConsistency"] = {"minimum_score": st_fc_minimum_score}

    if "Regex" in st_enabled_scanners:
        st_regex_expander = st.sidebar.expander(
            "Regex",
            expanded=False,
        )

        with st_regex_expander:
            # Blank lines are dropped: an empty pattern would match every output.
            st_regex_patterns = _split_lines(
                st.text_area(
                    "Enter patterns to ban (one per line)",
                    value="Bearer [A-Za-z0-9-._~+/]+",
                    height=200,
                )
            )

            st_regex_is_blocked = st.checkbox("Is blocked", value=True, key="regex_is_blocked")

            st_regex_redact = st.checkbox(
                "Redact",
                value=False,
                help="Replace the matched bad patterns with [REDACTED]",
                key="regex_redact",
            )

        settings["Regex"] = {
            "patterns": st_regex_patterns,
            "is_blocked": st_regex_is_blocked,
            "redact": st_regex_redact,
        }

    if "Relevance" in st_enabled_scanners:
        st_rele_expander = st.sidebar.expander(
            "Relevance",
            expanded=False,
        )

        with st_rele_expander:
            st_rele_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="rele_threshold",
            )

        settings["Relevance"] = {"threshold": st_rele_threshold}

    if "Sensitive" in st_enabled_scanners:
        st_sens_expander = st.sidebar.expander(
            "Sensitive",
            expanded=False,
        )

        with st_sens_expander:
            st_sens_entity_types = st_tags(
                label="Sensitive entities",
                text="Type and press enter",
                value=default_entity_types,
                suggestions=default_entity_types
                + ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
                maxtags=30,
                key="sensitive_entity_types",
            )
            st.caption(
                "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
            )
            st_sens_redact = st.checkbox("Redact", value=False, key="sens_redact")
            st_sens_threshold = st.slider(
                label="Threshold",
                value=0.0,
                min_value=0.0,
                max_value=1.0,
                step=0.1,
                key="sens_threshold",
            )

        settings["Sensitive"] = {
            "entity_types": st_sens_entity_types,
            "redact": st_sens_redact,
            "threshold": st_sens_threshold,
        }

    if "Sentiment" in st_enabled_scanners:
        st_sent_expander = st.sidebar.expander(
            "Sentiment",
            expanded=False,
        )

        with st_sent_expander:
            st_sent_threshold = st.slider(
                label="Threshold",
                value=-0.5,
                min_value=-1.0,
                max_value=1.0,
                step=0.1,
                key="sentiment_threshold",
                help="Negative values are negative sentiment, positive values are positive sentiment",
            )

        settings["Sentiment"] = {"threshold": st_sent_threshold}

    if "Toxicity" in st_enabled_scanners:
        st_tox_expander = st.sidebar.expander(
            "Toxicity",
            expanded=False,
        )

        with st_tox_expander:
            st_tox_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="toxicity_threshold",
            )

            st_tox_match_type = st.selectbox(
                "Match type",
                [e.value for e in ToxicityMatchType],
                index=1,
                key="toxicity_match_type",
            )

        settings["Toxicity"] = {
            "threshold": st_tox_threshold,
            "match_type": ToxicityMatchType(st_tox_match_type),
        }

    if "URLReachability" in st_enabled_scanners:
        st_url_expander = st.sidebar.expander(
            "URL Reachability",
            expanded=False,
        )

        # Use the expander as a container, like every other section (it was an `if`).
        with st_url_expander:
            settings["URLReachability"] = {}

    if "Gibberish" in st_enabled_scanners:
        st_gib_expander = st.sidebar.expander(
            "Gibberish",
            expanded=False,
        )

        with st_gib_expander:
            st_gib_threshold = st.slider(
                label="Threshold",
                value=0.7,
                min_value=0.0,
                max_value=1.0,
                step=0.1,
                key="gib_threshold",
            )

            st_gib_match_type = st.selectbox(
                "Match type", [e.value for e in GibberishMatchType], index=1, key="gib_match_type"
            )

        # Convert to the enum, consistent with Bias/Language/Toxicity.
        settings["Gibberish"] = {
            "match_type": GibberishMatchType(st_gib_match_type),
            "threshold": st_gib_threshold,
        }

    return st_enabled_scanners, settings
# Scanners that get_scanner constructs with ``use_onnx=True``.
_ONNX_SCANNERS = frozenset(
    {
        "BanCode",
        "BanTopics",
        "Bias",
        "Code",
        "Gibberish",
        "Language",
        "LanguageSame",
        "MaliciousURLs",
        "NoRefusal",
        "FactualConsistency",
        "Relevance",
        "Sensitive",
        "Toxicity",
    }
)


def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
    """Instantiate an output scanner by name with playground-specific options.

    Args:
        scanner_name: Scanner name passed to ``get_scanner_by_name``.
        vault: Added to the settings as ``vault``, but only for ``Deanonymize``.
        settings: Keyword arguments collected from the sidebar. This dict is
            not modified; a shallow copy is extended with ``vault`` and
            ``use_onnx`` where applicable.

    Returns:
        Whatever ``get_scanner_by_name`` returns for the given name and settings.
    """
    logger.debug("Initializing %s scanner", scanner_name)

    # Work on a copy so the caller's settings dict is not polluted with the
    # Vault object and the use_onnx flag.
    scanner_settings = dict(settings)

    if scanner_name == "Deanonymize":
        scanner_settings["vault"] = vault

    if scanner_name in _ONNX_SCANNERS:
        scanner_settings["use_onnx"] = True

    return get_scanner_by_name(scanner_name, scanner_settings)
def scan(
    vault: Vault,
    enabled_scanners: List[str],
    settings: Dict,
    prompt: str,
    text: str,
    fail_fast: bool = False,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Run the enabled output scanners over ``text`` in order, chaining sanitization.

    Each scanner receives ``prompt`` and the output as sanitized by the
    previous scanner. Progress is shown inside a Streamlit status container.

    Args:
        vault: Forwarded to ``get_scanner`` (used by ``Deanonymize``).
        enabled_scanners: Scanner names, run in this order.
        settings: Mapping of scanner name to its keyword arguments. Scanners
            missing from it are built with an empty settings dict.
        prompt: The prompt passed to every scanner alongside the output.
        text: The output to scan.
        fail_fast: Stop after the first scanner that reports the output as invalid.

    Returns:
        ``(sanitized_output, results)``. ``results`` holds one dict per scanner
        that ran, with keys ``scanner``, ``is_valid``, ``risk_score`` and
        ``took_sec`` (elapsed seconds, rounded to 2 decimals).
    """
    sanitized_output = text
    results: List[Dict[str, Any]] = []

    status_text = "Scanning prompt (fail fast mode)..." if fail_fast else "Scanning prompt..."
    with st.status(status_text, expanded=True) as status:
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")
            scanner = get_scanner(scanner_name, vault, settings.get(scanner_name, {}))

            start_time = time.monotonic()
            sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
            elapsed_sec = time.monotonic() - start_time

            results.append(
                {
                    "scanner": scanner_name,
                    "is_valid": is_valid,
                    "risk_score": risk_score,
                    "took_sec": round(elapsed_sec, 2),
                }
            )

            if fail_fast and not is_valid:
                break

        status.update(label="Scanning complete", state="complete", expanded=False)

    return sanitized_output, results