import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datetime

# Set page configuration
st.set_page_config(
    page_title="Qwen2.5-Coder Chat",
    page_icon="💬",
    layout="wide"
)

# Initialize session state for the conversation history
if 'messages' not in st.session_state:
    st.session_state.messages = []

@st.cache_resource
def load_model_and_tokenizer():
    try:
        # Display loading message
        with st.spinner("🔄 Loading model and tokenizer... This might take a few minutes..."):
            model_name = "Qwen/Qwen2.5-Coder-3B-Instruct"

            # Load tokenizer first
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )

            # Determine device and display info
            device = "cuda" if torch.cuda.is_available() else "cpu"
            st.info(f"💻 Using device: {device}")

            # Load model with settings appropriate for the device
            if device == "cuda":
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,  # Use float16 on GPU to halve memory usage
                    device_map="auto",
                    trust_remote_code=True
                ).eval()  # Set to evaluation mode
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map={"": device},
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                ).eval()  # Set to evaluation mode

            return tokenizer, model

    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        raise e

def generate_response(prompt, model, tokenizer, max_new_tokens=512, temperature=0.7, top_p=0.9):
    """Generate a response from the model with error handling."""
    try:
        # Tokenize the prompt and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate the response while showing a spinner
        with torch.no_grad(), st.spinner("🤔 Thinking..."):
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # Fall back to the EOS token if the tokenizer defines no pad token
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )

        # Decode the full sequence, then strip the prompt prefix to keep only the new text
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response[len(prompt):].strip()

    except torch.cuda.OutOfMemoryError:
        st.error("💾 GPU memory exceeded. Try reducing the maximum length or clearing the conversation.")
        return None
    except Exception as e:
        st.error(f"❌ Error generating response: {str(e)}")
        return None

# Main UI
st.title("💬 Qwen2.5-Coder Chat")

# Sidebar settings
with st.sidebar:
    st.header("⚙️ Settings")

    # Model settings
    max_length = st.slider(
        "Maximum Length 📏",
        min_value=64,
        max_value=2048,
        value=512,
        step=64
    )

    temperature = st.slider(
        "Temperature 🌡️",
        min_value=0.1,
        max_value=2.0,
        value=0.7,
        step=0.1
    )

    top_p = st.slider(
        "Top P 📊",
        min_value=0.1,
        max_value=1.0,
        value=0.9,
        step=0.1
    )

    # Clear conversation button
    if st.button("🗑️ Clear Conversation"):
        st.session_state.messages = []
        st.rerun()

# Load model
try:
    tokenizer, model = load_model_and_tokenizer()
except Exception:
    st.error("❌ Failed to load model. Please check the logs and refresh the page.")
    st.stop()

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(f"{message['content']}\n\n_{message['timestamp']}_")

# Chat input
if prompt := st.chat_input("💭 Ask me anything about coding..."):
    # Add user message to history
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    st.session_state.messages.append({
        "role": "user",
        "content": prompt,
        "timestamp": timestamp
    })

    # Display user message
    with st.chat_message("user"):
        st.markdown(f"{prompt}\n\n_{timestamp}_")

    # Generate and display response
    with st.chat_message("assistant"):
        # Prepare conversation context (limit to the last 3 messages to prevent context overflow)
        conversation = "\n".join(
            f"{'Human' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
            for msg in st.session_state.messages[-3:]
        ) + "\nAssistant:"

        response = generate_response(
            conversation,
            model,
            tokenizer,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p
        )

        if response:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            st.markdown(f"{response}\n\n_{timestamp}_")

            # Add response to chat history
            st.session_state.messages.append({
                "role": "assistant",
                "content": response,
                "timestamp": timestamp
            })
        else:
            st.error("❌ Failed to generate response. Please try again with different settings.")
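
# --- Usage note (a minimal sketch; the filename "app.py" is an assumption, use whatever
# --- name this script is saved under) ---
#
#   pip install streamlit torch transformers accelerate
#   streamlit run app.py
#
# `device_map="auto"` in load_model_and_tokenizer() requires the `accelerate` package,
# so it is included in the install line above.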