File size: 987 Bytes
0ba62b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a791472
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
<pre>
import evaluate
+from accelerate import Accelerator
+accelerator = Accelerator()
+# prepare() takes over device placement (hence the removed .to(device) calls
+# below), so every dataloader that is iterated later must be passed to it.
+train_dataloader, eval_dataloader, model, optimizer, scheduler = accelerator.prepare(
+    train_dataloader, eval_dataloader, model, optimizer, scheduler
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
-    loss.backward()
+    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+    # Gather predictions and their targets from all processes before scoring.
+    predictions, references = accelerator.gather_for_metrics(
+        (predictions, targets)
+    )
    metric.add_batch(
        predictions=predictions,
        references=references
    )
print(metric.compute())</pre>