accelerate_examples / code_samples /calculating_metrics
muellerzr's picture
muellerzr HF staff
Merge
a791472
raw
history blame
987 Bytes
<pre>
import evaluate
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = accelerator.prepare(
+ train_dataloader, eval_dataloader, model, optimizer, scheduler
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
optimizer.zero_grad()
inputs, targets = batch
- inputs = inputs.to(device)
- targets = targets.to(device)
outputs = model(inputs)
loss = loss_function(outputs, targets)
loss.backward()
optimizer.step()
scheduler.step()
model.eval()
for batch in eval_dataloader:
inputs, targets = batch
- inputs = inputs.to(device)
- targets = targets.to(device)
with torch.no_grad():
outputs = model(inputs)
predictions = outputs.argmax(dim=-1)
+ predictions, references = accelerator.gather_for_metrics(
+ (predictions, targets)
+ )
metric.add_batch(
predictions = predictions,
references = references
)
print(metric.compute())</pre>