muellerzr's picture
muellerzr HF staff
Merge
a791472
raw
history blame
503 Bytes
<pre>
+from accelerate import Accelerator
+accelerator = Accelerator()
+dataloader, model, optimizer, scheduler = accelerator.prepare(
+ dataloader, model, optimizer, scheduler
+)
for batch in dataloader:
optimizer.zero_grad()
inputs, targets = batch
- inputs = inputs.to(device)
- targets = targets.to(device)
outputs = model(inputs)
loss = loss_function(outputs, targets)
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
scheduler.step()</pre>