from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory that receives the generation transcripts written by this script.
BASE_PATH = "/fsx/loubna/projects/alignment-handbook/recipes/cosmo2/sft/data"

# Sampling hyper-parameters passed to `generate`; they are also embedded in the
# transcript file name so runs with different settings do not get mixed up.
TEMPERATURE = 0.2
TOP_P = 0.9

# Hub identifier of the checkpoint under test.
CHECKPOINT = "HuggingFaceTB/smollm-350M-instruct-add-basics"

print(f"💾 Loading the model and tokenizer: {CHECKPOINT}...")

# Prefer the GPU, but fall back to CPU so the script still runs on machines
# without CUDA instead of crashing when the model is moved to the device.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model_s = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(device)

print("🧪 Testing single-turn conversations...")

# Single-turn user prompts used to probe the checkpoint. Each entry is sent to
# the model as a fresh one-message conversation. The strings are kept verbatim
# (including informal or misspelled wording) so transcripts stay comparable
# between runs.
L = [
    # General chat, advice and simple factual questions.
    "Discuss the ethical implications of using AI in hiring processes.",
    "Give me some tips to improve my time management skills?",
    "Write a short dialogue between a customer and a waiter at a restaurant.",
    "wassup?",
    "Tell me a joke",
    "Hi, what are some popular dishes from Japan?",
    "What is the capital of Switzerland?",
    "What is the capital of France?",
    "What's the capital of Portugal?",
    "What is the capital of Morocco?",
    "How do I make pancakes?",
    "Write a poem about Helium",
    "Do you think it's important for a company to have a strong company culture? Why or why not?",
    "What is your favorite book?",
    "What is the most interesting fact you know?",
    "What is your favorite movie?",
    # Basic science and arithmetic.
    "Can you tell me what is gravity?",
    "Who discovered gravity?",
    "How does a rainbow form?",
    "What are the three states of matter?",
    "Why is the sky blue?",
    "What is the water cycle?",
    "How do magnets work?",
    "What is buoyancy?",
    "What is the speed of light?",
    "What's 2+2?",
    "what's the sum of 2 and 2?",
    "what's the sum of 2 and 3?",
    "What is the term for the process by which plants make their own food?",
    "If you have 8 apples and you give away 3, how many apples do you have left?",
    # Python programming.
    "How do I define a function in Python?",
    "Can you explain what a dictionary is in Python?",
    "Write a sort alrogithm in Python",
    "Write a fibonacci sequence in Python",
    "How do I read a file in Python?",
    "How do I make everything uppercase in Python?",
    "implement bubble sort in Python",
    # Creative writing.
    "Write a short story about a time traveler",
    "Describe a futuristic city in three sentences",
    "Describe a new color that doesn't exist",
    "Create a slogan for a time machine company",
    "Describe a world where plants can speak",
]

# Every transcript for this checkpoint / sampling configuration is appended to
# one file, so re-running the script extends that file rather than replacing it.
output_path = (
    Path(BASE_PATH)
    / f"{CHECKPOINT.split('/')[-1]}_temp_{TEMPERATURE}_topp{TOP_P}.txt"
)
# Create the target directory up front so a missing directory is reported
# before any generation time is spent, not after the first completion.
output_path.parent.mkdir(parents=True, exist_ok=True)

for prompt in L:
    print(f"🔮 {prompt}")

    # Each prompt is evaluated as an independent single-turn conversation.
    messages = [{"role": "user", "content": prompt}]
    # add_generation_prompt=True asks the chat template to end the text with
    # the assistant-turn header (when the template defines one), so the model
    # answers as the assistant instead of having to open that turn itself.
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    outputs = model_s.generate(
        inputs,
        max_new_tokens=200,
        top_p=TOP_P,
        do_sample=True,
        temperature=TEMPERATURE,
    )

    # Append after every prompt so partial results survive an interrupted run.
    # UTF-8 is set explicitly because decoded model output may contain
    # characters the platform's default encoding cannot represent.
    with output_path.open("a", encoding="utf-8") as f:
        f.write("=" * 50 + "\n")
        # The decoded sequence contains the templated prompt followed by the
        # completion, so each transcript entry is self-describing.
        f.write(tokenizer.decode(outputs[0]))
        f.write("\n")