Fix example usage code
#4
by
xzuyn
- opened
README.md
CHANGED
@@ -34,22 +34,26 @@ Eurus-RM-7B is trained on a mixture of [UltraInteract](https://huggingface.co/da
|
|
34 |
from transformers import AutoTokenizer, AutoModel
|
35 |
import torch
|
36 |
|
|
|
37 |
def test(model_path):
|
38 |
-
dataset = [
|
39 |
-
|
|
|
|
|
|
|
40 |
]
|
41 |
|
42 |
-
|
43 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
44 |
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
for example in dataset:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
53 |
|
54 |
test("openbmb/Eurus-RM-7b")
|
55 |
# Output: 47.4404296875
|
|
|
34 |
from transformers import AutoTokenizer, AutoModel
|
35 |
import torch
|
36 |
|
37 |
+
|
38 |
def test(model_path):
|
39 |
+
dataset = [ # cases in webgpt; we use the same template as Mistral-Instruct-v0.2
|
40 |
+
{
|
41 |
+
"chosen": "[INST] Sural relates to which part of the body? [\INST] The sural region is the muscular swelling of the back of the leg below the knee, formed chiefly by the bellies of the gastrocnemius and soleus muscles [1,2].",
|
42 |
+
"rejected": "[INST] Sural relates to which part of the body? [\INST] The Sural nerve runs down the side of the leg near the small saphenous vein, then passes forward below the lateral malleolus and continues on the outside of the foot as the lateral dorsal cutaneous nerve, which then communicates with the intermediate dorsal cutaneous nerve, which branches off to the side of the foot. [1]",
|
43 |
+
}
|
44 |
]
|
45 |
|
|
|
46 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
47 |
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
48 |
|
49 |
with torch.no_grad():
|
50 |
for example in dataset:
|
51 |
+
inputs = tokenizer(example["chosen"], return_tensors="pt")
|
52 |
+
chosen_reward = model(**inputs).item()
|
53 |
+
inputs = tokenizer(example["rejected"], return_tensors="pt")
|
54 |
+
rejected_reward = model(**inputs).item()
|
55 |
+
print(chosen_reward - rejected_reward)
|
56 |
+
|
57 |
|
58 |
test("openbmb/Eurus-RM-7b")
|
59 |
# Output: 47.4404296875
|