loubnabnl HF staff commited on
Commit
d235430
•
1 Parent(s): 38ff04c

Update test_prompts.py

Browse files
Files changed (1) hide show
  1. test_prompts.py +46 -71
test_prompts.py CHANGED
@@ -4,7 +4,7 @@ BASE_PATH = "/fsx/loubna/projects/alignment-handbook/recipes/cosmo2/sft/data"
4
  TEMPERATURE = 0.2
5
  TOP_P = 0.9
6
 
7
- CHECKPOINT = "loubnabnl/smollm-350M-instruct-add-basics"
8
 
9
  print(f"💾 Loading the model and tokenizer: {CHECKPOINT}...")
10
  device = "cuda"
@@ -13,21 +13,56 @@ model_s = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(device)
13
 
14
  print("🧪 Testing single-turn conversations...")
15
  L = [
16
- "Hi",
17
- "Hello",
 
 
 
18
  "Tell me a joke",
19
- "Who are you?",
20
- "What's your name?",
 
 
 
21
  "How do I make pancakes?",
 
 
 
 
 
 
 
22
  "Can you tell me what is gravity?",
23
- "What is the capital of Morocco?",
 
 
 
 
 
 
 
24
  "What's 2+2?",
25
- "Hi, what is 2+1?",
26
- "What's 3+5?",
27
- "Write a poem about Helium",
28
- "Hi, what are some popular dishes from Japan?",
29
- ]
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  for i in range(len(L)):
33
  print(f"🔮 {L[i]}")
@@ -44,63 +79,3 @@ for i in range(len(L)):
44
  f.write("=" * 50 + "\n")
45
  f.write(tokenizer.decode(outputs[0]))
46
  f.write("\n")
47
-
48
-
49
- print("🧪 Now testing multi-turn conversations...")
50
- # Multi-turn conversations
51
- messages_1 = [
52
- {"role": "user", "content": "Hi"},
53
- {"role": "assistant", "content": "Hello! How can I help you today?"},
54
- {"role": "user", "content": "What's 2+2?"},
55
- ]
56
- messages_2 = [
57
- {"role": "user", "content": "Hi"},
58
- {"role": "assistant", "content": "Hello! How can I help you today?"},
59
- {"role": "user", "content": "What's 2+2?"},
60
- {"role": "assistant", "content": "4"},
61
- {"role": "user", "content": "Why?"},
62
- ]
63
- messages_3 = [
64
- {"role": "user", "content": "Who are you?"},
65
- {"role": "assistant", "content": "I am an AI assistant. How can I help you today?"},
66
- {"role": "user", "content": "What's your name?"},
67
- ]
68
- messages_4 = [
69
- {"role": "user", "content": "Tell me a joke"},
70
- {"role": "assistant", "content": "Sure! Why did the tomato turn red?"},
71
- {"role": "user", "content": "Why?"},
72
- ]
73
- messages_5 = [
74
- {"role": "user", "content": "Can you tell me what is gravity?"},
75
- {
76
- "role": "assistant",
77
- "content": "Sure! Gravity is a force that attracts objects toward each other. It is what keeps us on the ground and what makes things fall.",
78
- },
79
- {"role": "user", "content": "Who discovered it?"},
80
- ]
81
- messages_6 = [
82
- {"role": "user", "content": "How do I make pancakes?"},
83
- {
84
- "role": "assistant",
85
- "content": "Sure! Here is a simple recipe for pancakes: Ingredients: 1 cup flour, 1 cup milk, 1 egg, 1 tbsp sugar, 1 tsp baking powder, 1/2 tsp salt. Instructions: 1. Mix all the dry ingredients together in a bowl. 2. Add the milk and egg and mix until smooth. 3. Heat a non-stick pan over medium heat. 4. Pour 1/4 cup of batter onto the pan. 5. Cook until bubbles form on the surface, then flip and cook for another minute. 6. Serve with your favorite toppings.",
86
- },
87
- {"role": "user", "content": "What are some popular toppings?"},
88
- ]
89
-
90
- L = [messages_1, messages_2, messages_3, messages_4, messages_5, messages_6]
91
-
92
- for i in range(len(L)):
93
- input_text = tokenizer.apply_chat_template(L[i], tokenize=False)
94
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
95
- outputs = model_s.generate(
96
- inputs, max_new_tokens=200, top_p=TOP_P, do_sample=True, temperature=TEMPERATURE
97
- )
98
- with open(
99
- f"{BASE_PATH}/{CHECKPOINT.split('/')[-1]}_temp_{TEMPERATURE}_topp{TOP_P}_MT.txt",
100
- "a",
101
- ) as f:
102
- f.write("=" * 50 + "\n")
103
- f.write(tokenizer.decode(outputs[0]))
104
- f.write("\n")
105
-
106
- print("🔥 Done!")
 
4
  TEMPERATURE = 0.2
5
  TOP_P = 0.9
6
 
7
+ CHECKPOINT = "HuggingFaceTB/smollm-350M-instruct-add-basics"
8
 
9
  print(f"💾 Loading the model and tokenizer: {CHECKPOINT}...")
10
  device = "cuda"
 
13
 
14
  print("🧪 Testing single-turn conversations...")
15
  L = [
16
+ # Witing and general knowledge prompts
17
+ "Discuss the ethical implications of using AI in hiring processes.",
18
+ "Give me some tips to improve my time management skills?",
19
+ "Write a short dialogue between a customer and a waiter at a restaurant.",
20
+ "wassup?",
21
  "Tell me a joke",
22
+ "Hi, what are some popular dishes from Japan?",
23
+ "What is the capital of Switzerland?",
24
+ "What is the capital of France?",
25
+ "What's the capital of Portugal?",
26
+ "What is the capital of Morocco?",
27
  "How do I make pancakes?",
28
+ "Write a poem about Helium",
29
+ "Do you think it's important for a company to have a strong company culture? Why or why not?",
30
+ "What is your favorite book?",
31
+ "What is the most interesting fact you know?",
32
+ "What is your favorite movie?",
33
+
34
+ # Science prompts
35
  "Can you tell me what is gravity?",
36
+ "Who discovered gravity?",
37
+ "How does a rainbow form?",
38
+ "What are the three states of matter?",
39
+ "Why is the sky blue?",
40
+ "What is the water cycle?",
41
+ "How do magnets work?",
42
+ "What is buoyancy?",
43
+ "What is the speed of light?",
44
  "What's 2+2?",
45
+ "what's the sum of 2 and 2?",
46
+ "what's the sum of 2 and 3?",
47
+ "What is the term for the process by which plants make their own food?",
48
+ "If you have 8 apples and you give away 3, how many apples do you have left?",
 
49
 
50
+ # Python prompts
51
+ "How do I define a function in Python?",
52
+ "Can you explain what a dictionary is in Python?",
53
+ "Write a sort alrogithm in Python",
54
+ "Write a fibonacci sequence in Python",
55
+ "How do I read a file in Python?",
56
+ "How do I make everything uppercase in Python?",
57
+ "implement bubble sort in Python",
58
+
59
+ # Creative prompts
60
+ "Write a short story about a time traveler",
61
+ "Describe a futuristic city in three sentences",
62
+ "Describe a new color that doesn't exist",
63
+ "Create a slogan for a time machine company",
64
+ "Describe a world where plants can speak",
65
+ ]
66
 
67
  for i in range(len(L)):
68
  print(f"🔮 {L[i]}")
 
79
  f.write("=" * 50 + "\n")
80
  f.write(tokenizer.decode(outputs[0]))
81
  f.write("\n")