File size: 713 Bytes
bd64901
 
 
67181a9
 
 
 
 
 
 
a04445e
67181a9
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
---
license: mit
---
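## Usage

The snippet below loads the pretrained classifier and reports its accuracy over a user-provided `test_dataloader`.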

```python
import torch
from informer_models import InformerConfig, InformerForSequenceClassification

model = InformerForSequenceClassification.from_pretrained("BrachioLab/supernova-classification")

# Choose the device explicitly: use a GPU when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # switch the model to evaluation mode before inference

# NOTE: `test_dataloader` is not defined in this snippet and must be supplied
# by the caller. The loop below expects it to yield dict batches of tensors
# that include a "labels" entry; any "objid" entry is dropped before the
# forward pass.
y_true = []
y_pred = []
for i, batch in enumerate(test_dataloader):
    print(f"processing batch {i}")
    batch = {k: v.to(device) for k, v in batch.items() if k != "objid"}
    with torch.no_grad():
        outputs = model(**batch)
    y_true.extend(batch['labels'].cpu().numpy())
    # argmax over dim=2 presumably means logits are (batch, 1, num_classes)
    # -- TODO confirm. squeeze(-1) removes only that trailing singleton; a
    # bare squeeze() would also collapse a batch of size 1 into a 0-d array,
    # which list.extend() cannot iterate.
    predictions = torch.argmax(outputs.logits, dim=2).squeeze(-1)
    y_pred.extend(predictions.cpu().numpy())

num_correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
# Guard against an empty dataloader so the summary line does not divide by zero.
accuracy = num_correct / len(y_true) if y_true else float("nan")
print(f"accuracy: {accuracy}")
```