model-inference / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import logging
import sys
import gc
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting application...")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
try:
logger.info("Loading tokenizer...")
    # Use the base model's tokenizer rather than the fine-tuned checkpoint's
base_model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(
base_model_id,
use_fast=True,
trust_remote_code=True
)
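    # Llama tokenizers have no pad token by default, so reuse the EOS token for padding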
tokenizer.pad_token = tokenizer.eos_token
logger.info("Tokenizer loaded successfully")
logger.info("Loading fine-tuned model in 8-bit...")
model_id = "htigenai/finetune_test_2"
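    # Load the fine-tuned checkpoint with bitsandbytes 8-bit quantization; max_memory caps per-device usage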
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        max_memory={0: "12GB", "cpu": "4GB"}
    )
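    # Inference only: eval() disables dropout and other training-time behavior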
model.eval()
logger.info("Model loaded successfully in 8-bit")
# Clear any residual memory
gc.collect()
torch.cuda.empty_cache()
def generate_text(prompt, max_tokens=100, temperature=0.7):
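        """Generate a response for `prompt`, sampling up to `max_tokens` new tokens at the given temperature."""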
try:
            # Format the prompt with the ### Human / ### Assistant template (not the tokenizer's built-in chat template)
formatted_prompt = f"### Human: {prompt}\n\n### Assistant:"
inputs = tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=256
).to(model.device)
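            # Sampled decoding with top-p and repetition penalties to curb looping in short answers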
with torch.inference_mode():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=True,
top_p=0.95,
repetition_penalty=1.2,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=3,
use_cache=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract assistant's response
if "### Assistant:" in response:
response = response.split("### Assistant:")[-1].strip()
# Clean up
del outputs, inputs
gc.collect()
torch.cuda.empty_cache()
return response
except Exception as e:
logger.error(f"Error during generation: {str(e)}")
return f"Error generating response: {str(e)}"
# Create Gradio interface
iface = gr.Interface(
fn=generate_text,
inputs=[
gr.Textbox(
lines=3,
placeholder="Enter your prompt here...",
label="Input Prompt",
max_lines=5
),
gr.Slider(
minimum=10,
maximum=100,
value=50,
step=10,
label="Max Tokens"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature"
)
],
outputs=gr.Textbox(
label="Generated Response",
lines=5
),
title="HTIGENAI Reflection Analyzer (8-bit)",
        description="Uses the Llama 3.1 base tokenizer with the fine-tuned model. Keep prompts concise for best results.",
examples=[
["What is machine learning?", 50, 0.7],
["Explain quantum computing", 50, 0.7],
],
cache_examples=False
)
# Launch interface
    iface.queue()  # use queue() rather than the deprecated enable_queue launch flag
    iface.launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True,
        max_threads=1
    )
except Exception as e:
logger.error(f"Application startup failed: {str(e)}")
raise