# loubnabnl's picture
# loubnabnl HF staff
# Update test_prompts.py
# d235430 verified
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Directory where the generated transcripts are written.
BASE_PATH = "/fsx/loubna/projects/alignment-handbook/recipes/cosmo2/sft/data"
# Sampling parameters passed to model.generate(); they are also embedded in the
# transcript filename so runs with different settings do not mix.
TEMPERATURE = 0.2
TOP_P = 0.9
# Checkpoint identifier passed to from_pretrained().
CHECKPOINT = "HuggingFaceTB/smollm-350M-instruct-add-basics"

print(f"💾 Loading the model and tokenizer: {CHECKPOINT}...")
# Prefer the GPU, but fall back to CPU so the script still runs (slowly) on
# machines without CUDA instead of crashing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model_s = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(device)
print("🧪 Testing single-turn conversations...")

# Prompt strings are kept verbatim (including informal phrasing and typos) so
# outputs stay comparable with earlier runs; do not "fix" their wording.

# Writing and general-knowledge prompts.
_GENERAL_PROMPTS = [
    "Discuss the ethical implications of using AI in hiring processes.",
    "Give me some tips to improve my time management skills?",
    "Write a short dialogue between a customer and a waiter at a restaurant.",
    "wassup?",
    "Tell me a joke",
    "Hi, what are some popular dishes from Japan?",
    "What is the capital of Switzerland?",
    "What is the capital of France?",
    "What's the capital of Portugal?",
    "What is the capital of Morocco?",
    "How do I make pancakes?",
    "Write a poem about Helium",
    "Do you think it's important for a company to have a strong company culture? Why or why not?",
    "What is your favorite book?",
    "What is the most interesting fact you know?",
    "What is your favorite movie?",
]

# Science and simple arithmetic prompts.
_SCIENCE_PROMPTS = [
    "Can you tell me what is gravity?",
    "Who discovered gravity?",
    "How does a rainbow form?",
    "What are the three states of matter?",
    "Why is the sky blue?",
    "What is the water cycle?",
    "How do magnets work?",
    "What is buoyancy?",
    "What is the speed of light?",
    "What's 2+2?",
    "what's the sum of 2 and 2?",
    "what's the sum of 2 and 3?",
    "What is the term for the process by which plants make their own food?",
    "If you have 8 apples and you give away 3, how many apples do you have left?",
]

# Python programming prompts.
_PYTHON_PROMPTS = [
    "How do I define a function in Python?",
    "Can you explain what a dictionary is in Python?",
    "Write a sort alrogithm in Python",
    "Write a fibonacci sequence in Python",
    "How do I read a file in Python?",
    "How do I make everything uppercase in Python?",
    "implement bubble sort in Python",
]

# Creative-writing prompts.
_CREATIVE_PROMPTS = [
    "Write a short story about a time traveler",
    "Describe a futuristic city in three sentences",
    "Describe a new color that doesn't exist",
    "Create a slogan for a time machine company",
    "Describe a world where plants can speak",
]

# Every prompt, in the order it is sent to the model.
L = _GENERAL_PROMPTS + _SCIENCE_PROMPTS + _PYTHON_PROMPTS + _CREATIVE_PROMPTS
# One transcript per (checkpoint, temperature, top_p) combination. The file is
# opened in append mode, so repeated runs accumulate in the same file.
output_file = f"{BASE_PATH}/{CHECKPOINT.split('/')[-1]}_temp_{TEMPERATURE}_topp{TOP_P}.txt"

# Open the transcript once rather than once per prompt. This also fails fast,
# before any generation, if BASE_PATH is not writable. UTF-8 is explicit
# because model output may contain non-ASCII text.
with open(output_file, "a", encoding="utf-8") as f:
    for prompt in L:
        print(f"🔮 {prompt}")
        messages = [{"role": "user", "content": prompt}]
        # NOTE(review): add_generation_prompt is not passed, so the template is
        # rendered without an explicit assistant-turn prefix (if the template
        # defines one) — confirm this is intended for this checkpoint.
        input_text = tokenizer.apply_chat_template(messages, tokenize=False)
        # Calling the tokenizer directly yields the same input_ids as
        # tokenizer.encode() plus an attention mask, which generate() would
        # otherwise have to infer (and warn about).
        encoded = tokenizer(input_text, return_tensors="pt").to(device)
        outputs = model_s.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=200,
            top_p=TOP_P,
            do_sample=True,
            temperature=TEMPERATURE,
        )
        f.write("=" * 50 + "\n")
        # Decode the whole sequence (prompt + completion); decode() keeps special
        # tokens by default, so chat-template markers appear in the transcript.
        f.write(tokenizer.decode(outputs[0]))
        f.write("\n")
        # Flush per prompt so finished results survive if the run is killed.
        f.flush()