Process reward models (PRMs) are trained to predict the correctness of each step at the positions of the `\n\n` step separators and the `<eos>` token.
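The position rule can be illustrated independently of the model, over plain lists of token ids. The sketch below is a hypothetical helper (`find_pattern` and `scoring_positions` are not part of any library). It assumes the prompt layout used in the usage example: a `\n\n# Solution` header, steps separated by `\n\n`, a `\n\n# Answer` header, and a trailing `<eos>`. The separators directly after the two headers are skipped because they do not end a step.

```python
from typing import List, Sequence


def find_pattern(ids: Sequence[int], pattern: Sequence[int]) -> List[int]:
    """Return every start index i where ids[i:i + len(pattern)] equals pattern."""
    n = len(pattern)
    target = list(pattern)
    return [i for i in range(len(ids) - n + 1) if list(ids[i:i + n]) == target]


def scoring_positions(
    ids: Sequence[int],
    solution_header: Sequence[int],
    answer_header: Sequence[int],
    sep: Sequence[int],
    eos_id: int,
) -> List[int]:
    """Hypothetical helper: positions of step-ending separators, then <eos>."""
    starts = find_pattern(ids, solution_header)
    if not starts:
        return []  # no solution header, nothing to score
    body_start = starts[0] + len(solution_header)  # first token after the header
    skip = {body_start}  # the separator right after "# Solution"
    answer_starts = find_pattern(ids, answer_header)
    if answer_starts:
        skip.add(answer_starts[-1] + len(answer_header))  # right after "# Answer"
    positions = [i for i in find_pattern(ids, sep) if i >= body_start and i not in skip]
    if eos_id in ids[body_start:]:
        positions.append(list(ids).index(eos_id, body_start))  # final answer
    return positions


# Toy ids (hypothetical): 13 13 = "\n\n", 1 2 = "# Solution", 1 3 = "# Answer", 0 = <eos>.
toy = [5, 13, 13, 1, 2, 13, 13, 7, 8, 13, 13, 9, 13, 13, 1, 3, 13, 13, 4, 0]
print(scoring_positions(toy, [13, 13, 1, 2], [13, 13, 1, 3], [13, 13], eos_id=0))
# → [9, 12, 19]
```

With the real tokenizer, the patterns would be the ids of `"\n\n# Solution"`, `"\n\n# Answer"` and `"\n\n"` (each encoded with `add_special_tokens=False` and without the leading prefix-space token), and `eos_id` would be `tokenizer.eos_token_id`.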

Usage:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-prm-metamath-level-1to3-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Tokenizer of the base Llemma 7B model.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

# A raw string keeps LaTeX such as \theta, \frac and \boxed intact; in a normal
# string literal, \t, \f, \b, \r and \a would become control characters.
qa_example = r"""# Question

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

# Solution

To convert from rectangular coordinates to polar coordinates, we use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \arctan\left(\frac{y}{x}\right)$.

In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = 3$ and $\theta = \arctan\left(\frac{3}{0}\right)$.

Since $\frac{3}{0}$ is undefined, we can say that $\theta$ is undefined.
However, we know that $\theta$ is an angle, and since $r > 0$, we can say that $\theta$ is any angle that satisfies $0 \le \theta < 2 \pi$.

Therefore, the polar coordinates of the point $(0,3)$ are $\boxed{(3,\theta)}$, where $0 \le \theta < 2 \pi$.

# Answer

(3,\theta)"""

# [1:] drops the prefix-space token that the SentencePiece-based tokenizer
# prepends when encoding a standalone string.
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id

input_ids = tokenizer.encode(qa_example)
# The final answer is scored at <eos>; append it if the tokenizer did not add one.
if input_ids[-1] != eos_token:
    input_ids.append(eos_token)

begin_solution_flag = False
candidate_positions = []

for idx, token_id in enumerate(input_ids):
    if input_ids[idx:idx + len(begin_solution_tokens)] == begin_solution_tokens:
        # Start collecting after "\n\n# Solution"; its leading "\n\n" ends no step.
        begin_solution_flag = True
        continue

    if begin_solution_flag and input_ids[idx:idx + len(scoring_tokens)] == scoring_tokens:
        candidate_positions.append(idx)

    if token_id == eos_token:
        candidate_positions.append(idx)
        break

# Delete the first and the second-to-last positions: they are the "\n\n" after
# "# Solution" and after "# Answer". What remains is the "\n\n" ending each
# solution step, followed by <eos> for the final answer.
del candidate_positions[0]
del candidate_positions[-2]

input_tensor = torch.tensor([input_ids]).to(model.device)

with torch.no_grad():
    logits = model(input_tensor).logits
    # Step score at each position: the mean of the logits over the vocabulary.
    scores = logits.mean(dim=-1)
    step_scores = scores[0, candidate_positions]
    step_probs = torch.sigmoid(step_scores.float())

# One probability per solution step, then one for the final answer (at <eos>).
print(step_probs)
```
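To read the output, each probability can be matched back to the text it scores. The sketch below uses hypothetical helpers (`split_steps` and `pair_steps` are not part of the model's tooling). It assumes the scored positions are the `\n\n` ending each solution step followed by `<eos>` for the final answer, so a solution with k blank-line-separated steps yields k + 1 probabilities.

```python
from typing import List, Tuple


def split_steps(qa_text: str) -> List[str]:
    """Split a '# Question / # Solution / # Answer' prompt into scored units.

    Returns the blank-line-separated solution steps followed by the final answer.
    """
    solution = qa_text.split("# Solution\n\n", 1)[1]
    body, answer = solution.split("\n\n# Answer\n\n", 1)
    return body.split("\n\n") + [answer]


def pair_steps(qa_text: str, probs: List[float]) -> List[Tuple[str, float]]:
    """Pair each step (and the final answer) with its predicted probability."""
    steps = split_steps(qa_text)
    if len(steps) != len(probs):
        raise ValueError(f"{len(steps)} steps but {len(probs)} scores")
    return list(zip(steps, probs))


toy = "# Question\n\nWhat is 1+1?\n\n# Solution\n\nAdd the numbers.\n\nSo the sum is 2.\n\n# Answer\n\n2"
for step, p in pair_steps(toy, [0.9, 0.8, 0.85]):
    print(f"{p:.2f}  {step}")
```

For the example above, `pair_steps(qa_example, step_probs.tolist())` should line up the four blank-line-separated steps and the final answer `(3,\theta)` with their scores; a single `\n` inside a step (as in the "Since ... However ..." step) does not split it.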