# ofai-it2v / app.py
# Author: fantaxy
# Last commit: "Update app.py" (8b62ce7, verified)
# File size: 2.43 kB
import gradio as gr
from huggingface_hub import InferenceClient
import os
import logging
from gradio_client import Client # ์ด๋ฏธ์ง€ ์ƒ์„ฑ API ํด๋ผ์ด์–ธํŠธ
# Logging setup: INFO and above go to the root logger's default handler.
logging.basicConfig(level=logging.INFO)

# Hugging Face inference client. The API token is read from the HF_TOKEN
# environment variable (None when unset).
# NOTE(review): hf_client is not referenced anywhere else in the visible file.
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus", token=os.getenv("HF_TOKEN"))

# Image-generation API client, constructed at import time.
# The server address can be overridden with the IMAGE_API_URL environment
# variable. It defaults to the previously hard-coded host, which is a raw IP
# over plain HTTP.
IMAGE_API_URL = os.getenv("IMAGE_API_URL", "http://211.233.58.202:7960/")
client = Client(IMAGE_API_URL)
def _extract_image_url(result):
    """Return the image URL/path contained in a ``client.predict`` result, or None.

    Accepted shapes:
      * dict   -> its ``"url"`` value, falling back to ``"path"`` (both may be absent);
      * str    -> the string itself (gradio_client usually returns a local file
                  path for image outputs — TODO confirm for /infer_t2i);
      * list/tuple -> resolved from the first element (multi-output endpoints
                  return one item per output).
    Anything else, or an empty value, yields None.
    """
    if isinstance(result, dict):
        return result.get("url") or result.get("path") or None
    if isinstance(result, str):
        return result or None
    if isinstance(result, (list, tuple)) and result:
        return _extract_image_url(result[0])
    return None


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate an image from the latest chat message and return its location.

    Used as the ``fn`` of the Gradio ``ChatInterface``. Only ``message`` is
    used: it is sent verbatim as the prompt to the remote ``/infer_t2i``
    endpoint with fixed generation settings. The other arguments are accepted
    because the ChatInterface passes them, but they are ignored here.

    Args:
        message: The user's latest chat message, used as the image prompt.
        history: Previous chat turns (unused).
        system_message: Value of the "System Prompt" textbox (unused).
        max_tokens: Value of the "Max new tokens" slider (unused).
        temperature: Value of the "Temperature" slider (unused).
        top_p: Value of the "Top-p" slider (unused).

    Returns:
        The image URL/path found in the API result. Otherwise
        ``"Failed to generate image."`` when the result holds no image, or
        ``"An error occurred: <exception text>"`` when the request raises.
    """
    # Only the image endpoint is called. The chat history, system prompt and
    # sampling sliders are not forwarded anywhere.
    try:
        result = client.predict(
            prompt=message,
            seed=123,  # fixed seed (randomize_seed=False) -> reproducible images
            randomize_seed=False,
            width=1024,
            height=576,
            guidance_scale=5,
            num_inference_steps=28,
            api_name="/infer_t2i",
        )
    except Exception as e:  # UI boundary: report the failure in the chat instead of crashing
        logging.exception("Error during API request: %s", str(e))
        return f"An error occurred: {str(e)}"

    image_url = _extract_image_url(result)
    if image_url is not None:
        return image_url

    # Only dict results can carry an explicit "error" field.
    if isinstance(result, dict):
        reason = result.get("error", "Unknown error")
    else:
        reason = f"unexpected result {result!r}"
    logging.error("Image generation failed with error: %s", reason)
    return "Failed to generate image."
# Visual configuration for the chat UI: a Hub theme plus CSS that hides the
# page footer.
theme = "Nymbo/Nymbo_Theme"
css = """
footer {
visibility: hidden;
}
"""

# Extra controls shown with the chat box. ChatInterface passes their values to
# respond() after (message, history), in this order: system prompt, max new
# tokens, temperature, top-p.
# NOTE(review): respond() in this file does not currently use these values.
_system_prompt_box = gr.Textbox(value="You are an AI assistant.", label="System Prompt")
_extra_inputs = [
    _system_prompt_box,
    gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

# Appearance options forwarded to the interface.
_interface_options = {"theme": theme, "css": css}

# Build the Gradio chat interface around respond().
demo = gr.ChatInterface(fn=respond, additional_inputs=_extra_inputs, **_interface_options)
def main():
    """Launch the Gradio chat demo defined in this module."""
    demo.launch()


if __name__ == "__main__":
    main()