Spaces:
Runtime error
Runtime error
import functools

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# (model id, human-readable size label), in menu order.
_GPT2_MODELS = (
    ("gpt2", "Small"),
    ("gpt2-medium", "Medium"),
    ("gpt2-large", "Large"),
    ("gpt2-xl", "Extra Large"),
)


def select_model(*, input_fn=None):
    """Interactively ask the user to pick a GPT-2 checkpoint.

    The menu is re-displayed until a valid answer is given. An answer may be
    the menu number (``"2"``) or the model id itself (``"gpt2-medium"``);
    surrounding whitespace and letter case are ignored.

    Args:
        input_fn: Callable used to read the answer; it is called with the
            prompt string and must return the user's reply. Defaults to the
            built-in ``input``. Override it to drive the menu
            non-interactively (e.g. in tests).

    Returns:
        The selected model id, e.g. ``"gpt2-large"``.

    Raises:
        EOFError: Propagated from ``input`` when stdin reaches end of input.
    """
    read = input if input_fn is None else input_fn

    # Accept both the menu number ("1".."N") and the literal model id.
    choices = {
        str(number): name
        for number, (name, _label) in enumerate(_GPT2_MODELS, start=1)
    }
    choices.update((name, name) for name, _label in _GPT2_MODELS)
    numbers = "/".join(str(number) for number in range(1, len(_GPT2_MODELS) + 1))

    while True:
        print("\nAvailable GPT-2 Models:")
        for number, (name, label) in enumerate(_GPT2_MODELS, start=1):
            print(f"{number}. {name} ({label})")
        choice = read(f"Select a model ({numbers}): ").strip().lower()
        if choice in choices:
            return choices[choice]
        print("Invalid choice. Please select a valid model.")
@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name):
    """Load and memoize the GPT-2 model/tokenizer pair for *model_name*.

    ``maxsize=1`` keeps only the most recently used checkpoint alive, so
    switching models evicts the previous one. Checkpoints therefore do not
    accumulate in memory, while repeated calls with the same model skip
    reloading it from disk.
    """
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    return model, tokenizer


def enhance_prompt(prompt, model_name, max_length=50, num_return_sequences=1):
    """Generate sampled continuations of *prompt* with a GPT-2 checkpoint.

    Args:
        prompt: Text to continue. Must contain non-whitespace characters.
        model_name: Hugging Face model id, e.g. ``"gpt2"`` or ``"gpt2-medium"``.
        max_length: Maximum total length, in tokens, of each output sequence
            *including* the prompt tokens (``generate``'s ``max_length``).
        num_return_sequences: Number of independent samples to return.

    Returns:
        A list of ``num_return_sequences`` decoded strings, each the prompt
        followed by its generated continuation.

    Raises:
        ValueError: If *prompt* is empty or whitespace-only.
    """
    if not prompt.strip():
        # An empty prompt tokenizes to zero tokens, which generate() cannot
        # start from; fail with a clear message instead of a deep error.
        raise ValueError("prompt must contain non-whitespace text")

    model, tokenizer = _load_model_and_tokenizer(model_name)

    # Calling the tokenizer (instead of .encode) also yields the attention
    # mask, so generate() does not have to infer it.
    inputs = tokenizer(prompt, return_tensors="pt")

    output = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        # Without do_sample=True decoding is greedy: top_k / top_p /
        # temperature below are ignored, and num_return_sequences > 1 is
        # rejected.
        do_sample=True,
        no_repeat_ngram_size=2,  # Avoid repetitive phrases
        top_k=50,                # Limit choices to top-k tokens
        top_p=0.95,              # Nucleus sampling cutoff
        temperature=0.7,         # Lower = less random
        # GPT-2 defines no pad token; reuse EOS explicitly rather than
        # letting generate() warn and pick one.
        pad_token_id=tokenizer.eos_token_id,
    )

    return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in output]
def main():
    """Run the interactive loop: pick a model, enter a prompt, print samples.

    Typing ``exit`` at the prompt quits. Case and surrounding whitespace are
    ignored. End of input (``input()`` raising ``EOFError``, e.g. Ctrl-D or
    a closed/non-interactive stdin) and Ctrl-C also quit cleanly instead of
    crashing with a traceback. Empty prompts are re-asked rather than passed
    to the model.
    """
    while True:
        try:
            model_name = select_model()
            prompt = input("Enter a prompt (or 'exit' to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # No more input is coming; end the session gracefully.
            print()
            break

        if prompt.lower() == "exit":
            break
        if not prompt:
            print("Please enter a non-empty prompt.")
            continue

        enhanced_prompts = enhance_prompt(prompt, model_name)
        print("\nEnhanced Prompts:")
        for idx, enhanced_prompt in enumerate(enhanced_prompts, start=1):
            print(f"Enhanced Prompt {idx}: {enhanced_prompt}\n")


if __name__ == "__main__":
    main()