import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import logging
import sys
import gc

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Starting application...")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
try:
    logger.info("Loading tokenizer...")
    # Use the base model's tokenizer instead of the fine-tuned checkpoint's
    base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id,
        use_fast=True,
        trust_remote_code=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Tokenizer loaded successfully")
logger.info("Loading fine-tuned model in 8-bit...") | |
model_id = "htigenai/finetune_test_2" | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
device_map="auto", | |
load_in_8bit=True, | |
torch_dtype=torch.float16, | |
low_cpu_mem_usage=True, | |
max_memory={0: "12GB", "cpu": "4GB"} | |
) | |
model.eval() | |
logger.info("Model loaded successfully in 8-bit") | |
# Clear any residual memory | |
gc.collect() | |
torch.cuda.empty_cache() | |
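    # generate_text is the Gradio handler: it wraps the prompt in the
    # "### Human:" / "### Assistant:" template and returns only the text
    # generated after the assistant marker.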
    def generate_text(prompt, max_tokens=100, temperature=0.7):
        try:
            # Format prompt with chat template
            formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"
            inputs = tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=256
            ).to(model.device)
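            # Sampling setup: nucleus sampling (top_p) plus a repetition
            # penalty and n-gram blocking to curb repeated phrases.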
            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    no_repeat_ngram_size=3,
                    use_cache=True
                )
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract assistant's response
            if "### Assistant:" in response:
                response = response.split("### Assistant:")[-1].strip()

            # Clean up
            del outputs, inputs
            gc.collect()
            torch.cuda.empty_cache()

            return response
        except Exception as e:
            logger.error(f"Error during generation: {str(e)}")
            return f"Error generating response: {str(e)}"
    # Create Gradio interface
    iface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(
                lines=3,
                placeholder="Enter your prompt here...",
                label="Input Prompt",
                max_lines=5
            ),
            gr.Slider(
                minimum=10,
                maximum=100,
                value=50,
                step=10,
                label="Max Tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            )
        ],
        outputs=gr.Textbox(
            label="Generated Response",
            lines=5
        ),
        title="HTIGENAI Reflection Analyzer (8-bit)",
        description="Uses the Llama 3.1 base tokenizer with a fine-tuned model. Keep prompts concise for best results.",
        examples=[
            ["What is machine learning?", 50, 0.7],
            ["Explain quantum computing", 50, 0.7],
        ],
        cache_examples=False
    )
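    # server_name="0.0.0.0" below makes the app reachable from outside the Spaces container.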
    # Launch interface with request queueing enabled
    iface.queue()
    iface.launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True,
        max_threads=1
    )
except Exception as e:
    logger.error(f"Application startup failed: {str(e)}")
    raise