Spaces:
Sleeping
Sleeping
from dotenv import load_dotenv, find_dotenv | |
import os | |
import requests | |
# Please ensure you have a .env file available with 'HUGGINGFACEHUB_API_TOKEN'
load_dotenv(find_dotenv())
# Indexing (rather than .get) raises KeyError at import time if the token is
# missing from the environment.
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]

# Text of the unsafe-content categories, read once at import time. The
# encoding is pinned so the read does not depend on the platform's locale
# default.
with open("data/llama_guard-unsafe_files.txt", "r", encoding="utf-8") as file:
    unsafe_categories = file.read()
def moderation_prompt_for_chat(chat, categories=None):
    """Build the safety-moderation prompt for a chat transcript.

    Args:
        chat: Sequence of turn mappings, each with a "content" entry. Turns
            are assumed to alternate strictly, starting with the user; any
            "role" entry is not inspected and no validation is performed.
        categories: Text of the unsafe-content categories to embed in the
            prompt. When None, the module-level ``unsafe_categories`` is
            used.

    Returns:
        The prompt string, wrapped in ``[INST] ... [/INST]``, asking for an
        assessment of "Agent" when the number of turns is even and of
        "User" when it is odd.
    """
    if categories is None:
        categories = unsafe_categories

    messages = [turn["content"] for turn in chat]

    # Even-indexed turns belong to the user and odd-indexed ones to the
    # agent, so the party being assessed (the last speaker) follows from the
    # turn count alone. An empty chat yields "Agent".
    speakers = ("User", "Agent")
    assessed_role = "Agent" if len(messages) % 2 == 0 else "User"

    header = f"""[INST] Task: Check if there is unsafe content in '{assessed_role}' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{categories}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
"""

    # Alternate User/Agent turns, each followed by two newlines.
    transcript = "".join(
        f"{speakers[index % 2]}: {message}\n\n"
        for index, message in enumerate(messages)
    )

    footer = f"""<END CONVERSATION>
Provide your safety assessment for {assessed_role} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

    return header + transcript + footer
def query(payload, timeout=120):
    """POST a JSON payload to the hosted inference endpoint and decode the reply.

    Args:
        payload: JSON-serializable request body.
        timeout: Seconds to wait for the endpoint, passed straight to
            ``requests.post``. ``None`` waits indefinitely, which is what
            happened before this parameter existed.

    Returns:
        The response body decoded from JSON. The HTTP status is not checked
        here, so an error body is returned like any other.

    Raises:
        requests.exceptions.Timeout: If the endpoint does not answer within
            ``timeout``.
    """
    api_url = "https://okoknht2arqo574k.us-east-1.aws.endpoints.huggingface.cloud"
    headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}",
        "Content-Type": "application/json",
    }
    # Without a timeout a stalled endpoint would block the caller forever.
    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    return response.json()
def moderate_chat(chat):
    """Send a chat transcript to the moderation endpoint.

    The prompt comes from ``moderation_prompt_for_chat``; whatever ``query``
    returns for it is handed back unmodified.
    """
    generation_settings = {
        "top_k": 1,
        "top_p": 0.2,
        "temperature": 0.1,
        "max_new_tokens": 512,
    }
    request_body = {
        "inputs": moderation_prompt_for_chat(chat),
        "parameters": generation_settings,
    }
    return query(request_body)