The PRM is trained to predict the correctness of each solution step at the positions of the `"\n\n"` step separator and the final `<eos>` token.
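
One tokenizer detail worth noting before the usage example: Llemma's SentencePiece-based tokenizer prepends a prefix token when encoding a string that starts with whitespace, which is why the code below drops the first id with `[1:]`. A minimal check of that behavior (the exact token ids depend on the tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

# The first id is the tokenizer's prefix token; the remaining ids
# are the ones actually matched against the input when locating "\n\n".
print(tokenizer.encode("\n\n", add_special_tokens=False))
print(tokenizer.encode("\n\n", add_special_tokens=False)[1:])
```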

Usage:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-prm-metamath-level-1to3-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

# Use a raw string so LaTeX escapes such as \theta and \arctan are not
# interpreted as Python escape sequences.
qa_example = r"""# Question

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

# Solution

To convert from rectangular coordinates to polar coordinates, we use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \arctan\left(\frac{y}{x}\right)$.

In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = 3$ and $\theta = \arctan\left(\frac{3}{0}\right)$.

Since $\frac{3}{0}$ is undefined, we can say that $\theta$ is undefined.
However, we know that $\theta$ is an angle, and since $r > 0$, we can say that $\theta$ is any angle that satisfies $0 \le \theta < 2 \pi$.

Therefore, the polar coordinates of the point $(0,3)$ are $\boxed{(3,\theta)}$, where $0 \le \theta < 2 \pi$.

# Answer

(3,\theta)"""

# Drop the leading prefix-space token the tokenizer emits for strings
# that start with whitespace, keeping only the tokens to match.
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id

input_ids = tokenizer.encode(qa_example)

begin_solution_flag = False

candidate_positions = []

# Collect every "\n\n" position inside the solution, plus the <eos> position.
for start_idx in range(len(input_ids)):

    if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
        begin_solution_flag = True

    if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
        candidate_positions.append(start_idx)

    if input_ids[start_idx] == eos_token:
        candidate_positions.append(start_idx)
        break

# Delete the first and the second-to-last candidate positions:
# they are the "\n\n" after "# Solution" and after "# Answer",
# which do not mark the end of a solution step.
del candidate_positions[0]
del candidate_positions[-2]

input_tensor = torch.tensor([input_ids]).to(model.device)
candidate_positions = torch.tensor(candidate_positions)

with torch.no_grad():
    logits = model(input_tensor).logits
    # The scalar step score at each position is the mean of the logits there.
    scores = logits.mean(dim=-1)
    step_scores = scores[0][candidate_positions]
    step_probs = torch.sigmoid(step_scores)

print(step_probs)
# tensor([0.7264, 0.8152, 0.7827, 0.4709, 0.5181])
```
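
The first four probabilities score the four solution steps (one per `"\n\n"` boundary) and the last scores the final answer at the `<eos>` position. To rerank several sampled solutions with the PRM, the per-step probabilities are usually reduced to a single solution-level score; taking the minimum or the product over steps are common conventions, not something this model prescribes. A minimal sketch, reusing `step_probs` from the example above:

```python
# Reduce per-step probabilities to one score per solution.
# min: a solution is only as strong as its weakest step.
# prod: treats the steps as independently correct.
solution_score_min = step_probs.min().item()
solution_score_prod = step_probs.prod().item()
print(f"min: {solution_score_min:.4f}, prod: {solution_score_prod:.4f}")
```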