##
<pre>
import evaluate
import torch
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = (
+    accelerator.prepare(
+        train_dataloader, eval_dataloader,
+        model, optimizer, scheduler
+    )
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
-    loss.backward()
+    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+    # Collect predictions and labels from every process before scoring
+    predictions, references = accelerator.gather_for_metrics(
+        (predictions, targets)
+    )
    metric.add_batch(
        predictions = predictions,
        references = references
    )
print(metric.compute())</pre>
##
In a distributed run, each process evaluates only its own shard of the validation set. Use the `Accelerator.gather_for_metrics`
method to gather the predictions and references from all processes, then calculate the metric on the gathered values.
The method also *automatically* drops the samples that were duplicated at the end of the dataset so that every
process receives the same number of items. The metric is therefore computed over each validation example exactly once.
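
To see why dropping the extra samples matters, here is a minimal pure-Python simulation. It does not use Accelerate or PyTorch, and it is not how the library is implemented: the helpers `shard_with_wraparound`, `gather`, and `accuracy` are hypothetical names, and the contiguous sharding scheme is an assumption chosen to keep the sketch short. It only illustrates the idea that samples reused to even out the shards must be removed after gathering.

```python
# Simulation only: all helper names below are hypothetical, not Accelerate APIs.

def shard_with_wraparound(indices, world_size):
    """Split indices evenly across processes, reusing the first samples to fill the gap."""
    padded = list(indices)
    remainder = len(padded) % world_size
    if remainder:
        padded += padded[: world_size - remainder]
    per_process = len(padded) // world_size
    return [padded[r * per_process:(r + 1) * per_process] for r in range(world_size)]


def gather(shards):
    """Concatenate per-process results in rank order, like an all-gather."""
    return [item for shard in shards for item in shard]


def accuracy(predictions, references):
    """Fraction of predictions that match their reference label."""
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)


labels = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0]        # 10 validation samples
model_output = [1, 1, 1, 0, 1, 0, 0, 1, 1, 0]  # wrong only on sample 0

# 10 samples over 4 processes: samples 0 and 1 are reused so each process gets 3.
shards = shard_with_wraparound(list(range(len(labels))), world_size=4)

gathered_predictions = gather([[model_output[i] for i in shard] for shard in shards])
gathered_references = gather([[labels[i] for i in shard] for shard in shards])

# Naive gather: 12 values, so the mistake on sample 0 is counted twice.
naive_accuracy = accuracy(gathered_predictions, gathered_references)

# Dropping the trailing duplicates restores the true dataset size of 10.
true_size = len(labels)
deduped_accuracy = accuracy(
    gathered_predictions[:true_size], gathered_references[:true_size]
)

print(f"naive: {naive_accuracy:.4f}, deduplicated: {deduped_accuracy:.4f}")
```

In this toy setup the naive score is computed over 12 values instead of 10, so it differs from the true accuracy. With `gather_for_metrics` you do not write the truncation step yourself.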
##
To learn more, check out the related documentation:
- <a href="https://huggingface.co/docs/accelerate/en/quicktour#distributed-evaluation" target="_blank">Quicktour - Calculating metrics</a>
- <a href="https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics" target="_blank">API reference</a>
- <a href="https://github.com/huggingface/accelerate/blob/main/examples/by_feature/multi_process_metrics.py" target="_blank">Example script</a>