import gc
import json
import logging
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime
from threading import Lock
from typing import List

import torch
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('chat_system.log'),
        logging.StreamHandler()
    ]
)

@dataclass
class ConversationTurn:
    """Represents a single turn in the conversation."""
    role: str
    content: str
    timestamp: float = field(default_factory=time.time)
    token_count: int = 0

class TokenManager:
    """Manages token counting and context window optimization."""

    def __init__(self, tokenizer, max_context_tokens: int = 4096):
        self.tokenizer = tokenizer
        self.max_context_tokens = max_context_tokens
        self._token_count_cache = {}
        self.cache_lock = Lock()

    def count_tokens(self, text: str) -> int:
        """Count tokens with caching for efficiency."""
        with self.cache_lock:
            if text not in self._token_count_cache:
                tokens = self.tokenizer.encode(text, add_special_tokens=True)
                self._token_count_cache[text] = len(tokens)
            return self._token_count_cache[text]

    def optimize_context(self, turns: List[ConversationTurn], max_turns: int = 10) -> List[ConversationTurn]:
        """Trim history to the token budget and turn cap, newest turns first."""
        total_tokens = 0
        optimized_turns = []

        # Always keep the most recent turn, even if it alone exceeds the budget.
        if turns:
            last_turn = turns[-1]
            total_tokens += last_turn.token_count
            optimized_turns.append(last_turn)

        # Walk backwards through earlier turns until the budget or cap is hit.
        for turn in reversed(turns[:-1]):
            if total_tokens + turn.token_count > self.max_context_tokens:
                break
            if len(optimized_turns) >= max_turns:
                break
            total_tokens += turn.token_count
            optimized_turns.insert(0, turn)

        return optimized_turns
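
    # Worked example of the trimming policy (hypothetical token counts): with a
    # budget of 4096 and turns of 3000, 1500, and 800 tokens, the method keeps
    # the 800-token turn, adds the 1500-token turn (2300 total), then stops
    # because adding the 3000-token turn would overflow the budget.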


class ConversationManager:
    """Manages conversation state and history."""

    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.turns: List[ConversationTurn] = []
        self.system_prompt = """You are a highly capable AI assistant with expertise in business and technical domains.
You provide detailed, well-reasoned responses while maintaining a professional tone.
Focus on delivering accurate, contextual information without repeating previous conversation details."""
        self.system_tokens = token_manager.count_tokens(self.system_prompt)

    def add_turn(self, role: str, content: str):
        """Add a new conversation turn with token counting."""
        turn = ConversationTurn(
            role=role,
            content=content,
            token_count=self.token_manager.count_tokens(content)
        )
        self.turns.append(turn)

    def get_prompt(self, include_system: bool = True) -> str:
        """Generate an optimized prompt for model input."""
        optimized_turns = self.token_manager.optimize_context(self.turns)

        components = []
        if include_system:
            components.append(f"System: {self.system_prompt}")

        for turn in optimized_turns:
            role_prefix = "Human" if turn.role == "user" else "Assistant"
            components.append(f"{role_prefix}: {turn.content}")

        # End with an "Assistant:" cue so the model continues as the assistant
        # rather than predicting another "Human:" turn.
        components.append("Assistant:")
        return "\n\n".join(components)
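
    # Note: Qwen2.5-Instruct models ship with a chat template; building the
    # prompt via tokenizer.apply_chat_template(messages, tokenize=False,
    # add_generation_prompt=True) would match the model's training format more
    # closely than this plain "Human:/Assistant:" framing.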


class ResponseGenerator:
    """Handles model inference and response generation."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        # Note: do_sample=True together with num_beams=2 selects beam-search
        # multinomial sampling in transformers, which is noticeably slower than
        # plain sampling; early_stopping and length_penalty only apply to beams.
        self.base_params = {
            'do_sample': True,
            'top_k': 50,
            'top_p': 0.95,
            'temperature': 0.8,
            'repetition_penalty': 1.1,
            'no_repeat_ngram_size': 4,
            'num_beams': 2,
            'early_stopping': True,
            'length_penalty': 1.2,
            # min_new_tokens, not min_length: min_length counts the prompt too,
            # so 10 would be a no-op for any non-trivial prompt.
            'min_new_tokens': 10,
            'use_cache': True,
        }

    @contextmanager
    def inference_mode(self):
        """Context manager for inference optimization."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        try:
            with torch.inference_mode():
                yield
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
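
    # torch.inference_mode() goes further than torch.no_grad(): it also skips
    # view tracking and version-counter bookkeeping, so tensors created inside
    # it cannot later participate in autograd — fine for pure generation.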

    def calculate_dynamic_length(self, input_text: str, conversation_length: int) -> int:
        """Calculate dynamic response length based on input and conversation context."""
        input_tokens = len(self.tokenizer.encode(input_text))
        base_length = max(100, input_tokens * 2)

        # Responses may grow up to 2x as the conversation deepens.
        complexity_factor = min(2.0, 1.0 + (conversation_length / 20))
        dynamic_length = int(base_length * complexity_factor)

        return min(max(dynamic_length, 100), 2048)
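
    # Example: a 120-token input ten turns in gives base_length = 240 and
    # complexity_factor = 1.5, so the cap is min(max(360, 100), 2048) = 360.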

    def generate_response(self, prompt: str, conversation_length: int) -> str:
        """Generate a response with dynamic length and advanced parameters."""
        with self.inference_mode():
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096
            ).to(self.device)

            max_new_tokens = self.calculate_dynamic_length(prompt, conversation_length)

            generation_params = {
                **self.base_params,
                'max_new_tokens': max_new_tokens,
                'pad_token_id': self.tokenizer.pad_token_id,
                'eos_token_id': self.tokenizer.eos_token_id,
            }

            outputs = self.model.generate(
                **inputs,
                **generation_params
            )

            # Decode only the newly generated tokens. Splitting the full decoded
            # text on "Assistant:" is fragile: it breaks whenever the reply
            # itself contains that marker.
            new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


class EnterpriseQwenChat:
    """Main chat interface with enterprise-grade features."""

    def __init__(self, model_directory: str = "./qwen"):
        self.console = Console()
        self.model_directory = model_directory
        self.setup_components()

    def setup_components(self):
        """Initialize tokenizer, model, and helper components, using CUDA when available."""
        try:
            self.console.print("Initializing Enterprise Qwen Chat...", style="bold yellow")

            # The tokenizer is pulled from the Hub; this assumes model_directory
            # holds weights for the same Qwen2.5-0.5B-Instruct checkpoint.
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            config = AutoConfig.from_pretrained(self.model_directory)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_directory,
                config=config,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
            )

            # With device_map="auto", accelerate has already placed the model
            # and calling .to() on it raises an error, so only move it
            # explicitly in the CPU path.
            if not torch.cuda.is_available():
                self.model.to("cpu")
            self.model.eval()

            self.token_manager = TokenManager(self.tokenizer)
            self.conversation_manager = ConversationManager(self.token_manager)
            self.response_generator = ResponseGenerator(self.model, self.tokenizer)

            self.console.print("[bold green]System initialized successfully![/bold green]")

        except Exception as e:
            logging.error(f"Initialization failed: {str(e)}")
            raise
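
    # Note: Qwen2.5 uses the Qwen2 architecture, which transformers supports
    # natively from v4.37; older versions would need trust_remote_code=True.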

    def save_conversation(self) -> str:
        """Save the conversation with metadata to a timestamped JSON file."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f'conversation_{timestamp}.json'

        conversation_data = {
            'timestamp': timestamp,
            'turns': [
                {
                    'role': turn.role,
                    'content': turn.content,
                    'timestamp': turn.timestamp,
                    'token_count': turn.token_count
                }
                for turn in self.conversation_manager.turns
            ],
            'metadata': {
                'total_turns': len(self.conversation_manager.turns),
                'total_tokens': sum(turn.token_count for turn in self.conversation_manager.turns)
            }
        }

        # ensure_ascii=False keeps non-ASCII conversation text readable on disk.
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, indent=2, ensure_ascii=False)

        return filename
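
    # Resulting file shape:
    # {"timestamp": "...",
    #  "turns": [{"role", "content", "timestamp", "token_count"}, ...],
    #  "metadata": {"total_turns": N, "total_tokens": M}}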

    def run(self):
        """Run the chat interface with enhanced features."""
        self.console.print(Panel.fit(
            "[bold green]Enterprise Qwen Chat System[/bold green]\n"
            "[italic]Commands:\n"
            "- 'exit' or 'quit': End conversation\n"
            "- 'save': Save conversation\n"
            "- 'clear': Clear conversation history[/italic]"
        ))

        while True:
            try:
                user_input = self.console.input("[bold cyan]You:[/bold cyan] ").strip()

                # Ignore empty input rather than sending a blank turn to the model.
                if not user_input:
                    continue

                if user_input.lower() in ['exit', 'quit']:
                    log_file = self.save_conversation()
                    self.console.print(f"Conversation saved to: {log_file}", style="bold green")
                    break

                if user_input.lower() == 'save':
                    log_file = self.save_conversation()
                    self.console.print(f"Conversation saved to: {log_file}", style="bold green")
                    continue

                if user_input.lower() == 'clear':
                    self.conversation_manager.turns.clear()
                    self.console.print("Conversation history cleared.", style="bold yellow")
                    continue

                self.conversation_manager.add_turn("user", user_input)

                with self.console.status("[bold yellow]Generating response...[/bold yellow]"):
                    start_time = time.time()
                    prompt = self.conversation_manager.get_prompt()
                    response = self.response_generator.generate_response(
                        prompt,
                        len(self.conversation_manager.turns)
                    )

                self.conversation_manager.add_turn("assistant", response)

                end_time = time.time()

                self.console.print(Markdown(f"**AI:** {response}"))
                self.console.print(
                    f"[dim italic](Generated in {end_time - start_time:.2f} seconds)[/dim italic]\n"
                )

            except KeyboardInterrupt:
                self.console.print("\nGracefully shutting down...", style="bold yellow")
                self.save_conversation()
                break

            except Exception as e:
                logging.error(f"Error during chat: {str(e)}")
                self.console.print(
                    "[bold red]An error occurred. The conversation has been saved.[/bold red]"
                )
                self.save_conversation()
                break


if __name__ == "__main__":
    chat = EnterpriseQwenChat()
    chat.run()
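
# Typical usage (assuming the weights live in ./qwen, the default
# model_directory): run this file directly and chat at the "You:" prompt;
# 'save' writes a JSON transcript, 'clear' resets history, 'exit' quits.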