Spaces:
Sleeping
Sleeping
import os | |
import random | |
import numpy as np | |
import json | |
import re | |
import groq | |
from groq import Groq | |
import gradio as gr | |
# groq
# One Groq API client shared by every request in this app. The API key is read
# from the "Groq_Api_Key" environment variable; os.getenv yields None when the
# variable is unset, and that value is handed to the client unchanged.
client = Groq(api_key=os.getenv("Groq_Api_Key"))
def _extract_groq_error_message(e):
    """Return the API error message carried by exception *e*, or None.

    Looks first at the exception's ``body`` attribute (when it has one), then at
    a ``{...}`` payload embedded in ``e.args[0]``. The embedded payload is parsed
    as JSON, or failing that as a Python literal, because it may be a Python
    dict repr whose strings use single quotes (and may contain apostrophes).
    NOTE(review): assumes Groq SDK errors expose the parsed response as ``body``;
    the args-based fallback covers the case where they do not.
    """
    import ast  # only needed on this error path

    candidates = []
    body = getattr(e, "body", None)
    if body is not None:
        candidates.append(body)

    raw = e.args[0] if e.args else None
    if isinstance(raw, dict):
        candidates.append(raw)
    elif isinstance(raw, str):
        match = re.search(r'(\{.*\})', raw, re.DOTALL)
        if match:
            payload = match.group(1)
            for parse in (json.loads, ast.literal_eval):
                try:
                    candidates.append(parse(payload))
                    break
                except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                    continue

    for data in candidates:
        if not isinstance(data, dict):
            continue
        # The message may sit under an "error" key or at the top level.
        inner = data.get("error", data)
        if isinstance(inner, dict) and isinstance(inner.get("message"), str):
            return inner["message"]
    return None


def handle_groq_error(e, model_name):
    """Re-raise an exception from a Groq API call as a user-facing ``gr.Error``.

    For authentication and rate-limit errors the API's own error message is
    shown when it can be recovered from the exception. In every other case,
    including those two when no message can be recovered, a generic
    "Error during Groq API call: ..." message is used. Organisation ids
    (``org_...``) are censored from whatever text is shown.

    Args:
        e: the exception raised by the Groq client.
        model_name: name of the model that was being called. Currently unused;
            kept so existing call sites keep working.

    Raises:
        gr.Error: always.
    """
    message = None
    if isinstance(e, (groq.AuthenticationError, groq.RateLimitError)):
        message = _extract_groq_error_message(e)
    if message is None:
        message = f"Error during Groq API call: {e}"
    # Don't leak the organisation id that may appear in API error text.
    message = re.sub(r'org_[a-zA-Z0-9]+', 'org_(censored)', message)
    raise gr.Error(message) from e
# chat
def create_history_messages(history):
    """Convert Gradio "tuples"-style chat history into chat-completion messages.

    Args:
        history: iterable of turns, each indexable as ``[user, assistant]``
            (the shape used with ``gr.ChatInterface(type="tuples")``).

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in conversation order.
        An empty or None side of a turn is skipped. A turn with no user text
        (e.g. an assistant-only opening message) yields only its assistant
        message, and a turn still awaiting a reply yields only its user message.
    """
    history_messages = []
    for turn in history:
        user_content, assistant_content = turn[0], turn[1]
        # A None user side (e.g. [None, opening_story]) would otherwise be sent
        # as content=None, which the API would likely reject.
        if user_content:
            history_messages.append({"role": "user", "content": user_content})
        if assistant_content:  # the assistant reply may not exist yet
            history_messages.append({"role": "assistant", "content": assistant_content})
    return history_messages
# Inclusive upper bound for the random seeds sent with each completion request:
# the largest value a signed 32-bit integer can hold.
MAX_SEED = int(np.iinfo(np.int32).max)
def generate_initial_story():
    """Ask the model for the opening scene of the haunted-house story.

    Sends a single user prompt (no history) with a fresh random seed.

    Returns:
        The text of the first completion choice, or None if the Groq error
        handler returns without raising.

    Raises:
        gr.Error: via ``handle_groq_error`` when the Groq API call fails.
    """
    # Previously the handlers referenced an undefined `model`, so a failing
    # API call surfaced as NameError instead of the real error.
    model = "mixtral-8x7b-32768"
    initial_prompt = [{"role": "user", "content": "Create a short, spooky, and slightly comical story about being trapped in a haunted house. Describe the initial setting and the first challenge the character faces."}]
    seed = random.randint(1, MAX_SEED)
    try:
        initial_completion = client.chat.completions.create(
            messages=initial_prompt,
            model=model,
            temperature=0.7,
            max_tokens=1000,
            top_p=0.5,
            seed=seed,
        )
    except groq.APIError as e:
        # APIError is the base of AuthenticationError/RateLimitError; other
        # API failures get the handler's generic message.
        handle_groq_error(e, model)
        return None
    return initial_completion.choices[0].message.content
def generate_response(prompt, history):
    """Stream the storyteller's reply to *prompt*, given the prior chat *history*.

    The request is built as: the storyteller system prompt, then the prior
    turns (via ``create_history_messages``), then the new user prompt. It is
    sent to the Groq chat-completions API with a fresh random seed.

    Args:
        prompt: the user's latest message.
        history: prior turns in Gradio "tuples" format ([user, assistant] pairs).

    Yields:
        The accumulated assistant reply so far, once per chunk that carries text.

    Raises:
        gr.Error: via ``handle_groq_error`` when the Groq API call fails.
    """
    model = "mixtral-8x7b-32768"
    # The system prompt leads the conversation so it frames every turn after it.
    messages = [{"role": "system", "content": "You are an Interactive Story Teller for an Halloween Spooky Escape Room Game! You are meant to generate a random story and let the user type their actions till they manage to find the exit."}]
    messages.extend(create_history_messages(history))
    messages.append({"role": "user", "content": prompt})
    seed = random.randint(1, MAX_SEED)
    try:
        stream = client.chat.completions.create(
            messages=messages,
            model=model,
            temperature=0.7,
            # NOTE(review): 32768 matches the context size in the model id;
            # confirm the API accepts it once the prompt tokens are counted.
            max_tokens=32768,
            top_p=0.5,
            seed=seed,
            stop=None,
            stream=True,
        )
        response = ""
        for chunk in stream:
            if not chunk.choices:  # tolerate chunks that carry no choices
                continue
            delta_content = chunk.choices[0].delta.content
            if delta_content is not None:
                response += delta_content
                yield response
    except groq.APIError as e:
        # The handlers previously passed an undefined `model`, raising NameError
        # and hiding the real API error from the user.
        handle_groq_error(e, model)
# Set initial chatbot history with AI-generated initial story.
# NOTE(review): `chatbot_history` appears unused in this file (it is not passed
# to the Chatbot below), so this startup API call has no visible effect. If it
# is wired in as the Chatbot's initial value, make sure its leading None user
# turn is not sent to the API as message content.
initial_message = generate_initial_story()
chatbot_history = [[None, initial_message]]

with gr.Blocks() as interface:
    gr.Markdown("# 🎃Halloween Escape Room🎃")
    gr.Markdown("Can you escape the haunted house? Type your actions to interact with the environment <br> Made by <a href='https://linktr.ee/nick088'>Nick088</a>.")
    chatbot = gr.Chatbot(placeholder="<strong>Send 'start'</strong><br>")
    gr.ChatInterface(
        fn=generate_response,
        type="tuples",
        chatbot=chatbot,
    )

if __name__ == "__main__":
    # Only start the server when run as a script, so importing this module
    # (e.g. from tests or tooling) does not launch the app.
    interface.launch(share=True)