muellerzr's picture
muellerzr HF staff
Merge
a791472
raw
history blame
452 Bytes
<pre>
from accelerate import Accelerator

# NOTE(review): `dataloader`, `model`, `optimizer`, `scheduler` and
# `loss_function` are placeholders the reader is assumed to define.
accelerator = Accelerator()
# Hand every training object to Accelerate in one call; the returned
# objects replace the originals for the rest of the script.
dataloader, model, optimizer, scheduler = accelerator.prepare(
    dataloader, model, optimizer, scheduler
)
for batch in dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    # Route the backward pass through the accelerator instead of
    # calling loss.backward() directly.
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()
+accelerator.save_state("checkpoint_dir")</pre>