Token Classification
GLiNER
PyTorch
multilingual
bert
Inference Endpoints
Rejebc commited on
Commit
b53706a
1 Parent(s): 1903c5a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -3
handler.py CHANGED
@@ -1,6 +1,5 @@
1
  from typing import Dict, List, Any
2
  from gliner import GLiNER
3
- import os
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
@@ -19,7 +18,7 @@ class EndpointHandler:
19
  """
20
  # Get inputs and labels
21
  inputs = data.get("inputs", "")
22
- labels = data.get("labels", [])
23
 
24
  # Predict entities using GLiNER
25
  entities = self.model.predict_entities(inputs, labels)
@@ -30,7 +29,7 @@ class EndpointHandler:
30
  formatted_entity = {
31
  "word": entity["text"],
32
  "entity_group": entity["label"], # Assuming entity["label"] contains the label
33
- "score": entity.get("score", 1.0) # Assuming a default score of 1.0 if not provided
34
  }
35
  formatted_results.append(formatted_entity)
36
 
 
1
  from typing import Dict, List, Any
2
  from gliner import GLiNER
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, path=""):
 
18
  """
19
  # Get inputs and labels
20
  inputs = data.get("inputs", "")
21
+ labels = ["party", "document title"]
22
 
23
  # Predict entities using GLiNER
24
  entities = self.model.predict_entities(inputs, labels)
 
29
  formatted_entity = {
30
  "word": entity["text"],
31
  "entity_group": entity["label"], # Assuming entity["label"] contains the label
32
+ "score": entity.get("score", 1.0) # Default to 1.0 if score not provided
33
  }
34
  formatted_results.append(formatted_entity)
35