File size: 1,916 Bytes
cf6f740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from transformers import TokenClassificationPipeline

class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
  """Token classification with label decoding constrained by B-/I- transitions.

  Rather than taking the arg-max label of every token independently,
  ``postprocess`` picks the label sequence with the highest summed logits
  among the sequences permitted by ``self.transition``:

  * ``B-X`` may only be followed by ``I-X``;
  * ``I-X`` may be followed by ``I-X`` or by any label not starting with ``I-``;
  * any other label may be followed by any label not starting with ``I-``.
  """

  def __init__(self,**kwargs):
    """Initialise the base pipeline and build ``self.transition``.

    All keyword arguments are forwarded to ``TokenClassificationPipeline``.
    ``self.transition[a,b]`` is 0 when label id ``b`` may follow label id
    ``a`` and NaN otherwise, so NaN-aware reductions ignore forbidden moves.

    Raises:
      KeyError: if a ``B-X`` label has no matching ``I-X`` label.
    """
    import numpy
    super().__init__(**kwargs)
    label2id=self.model.config.label2id
    # Labels that do not continue a span; they may follow any non-"B-" label.
    openers=[label for label in label2id if not label.startswith("I-")]
    # NOTE(review): ids are used as matrix indices, so this assumes they are
    # exactly 0..len(label2id)-1 — confirm for new model configs.
    self.transition=numpy.full((len(label2id),len(label2id)),numpy.nan)
    for label,src in label2id.items():
      if label.startswith("B-"):
        successors=["I-"+label[2:]]
      elif label.startswith("I-"):
        successors=[label]+openers
      else:
        successors=openers
      for successor in successors:
        self.transition[src,label2id[successor]]=0

  def check_model_type(self,supported_models):
    """No-op override: the base class's supported-architecture check is not run."""

  def postprocess(self,model_outputs,**kwargs):
    """Convert one sentence's model outputs into a list of entity dicts.

    Args:
      model_outputs: mapping with ``logits``, ``offset_mapping`` and
        ``sentence``; if ``logits`` is missing it is treated as a sequence
        and only its first element is processed.
      **kwargs: an ``aggregation_strategy`` other than ``"none"`` merges
        ``I-`` tokens into the entity before them.

    Returns:
      One dict per token with a non-empty character span (or per merged
      entity when aggregating), holding ``start``/``end`` offsets into the
      sentence, ``score`` (softmax probability of the chosen label; the
      minimum over merged tokens), ``text``, and ``entity`` or, when
      aggregating, ``entity_group``.
    """
    import numpy
    if "logits" not in model_outputs:
      return self.postprocess(model_outputs[0],**kwargs)
    logits=model_outputs["logits"][0].numpy()
    # Numerically stable softmax over the label axis.
    exp=numpy.exp(logits-numpy.max(logits,axis=-1,keepdims=True))
    probs=exp/exp.sum(axis=-1,keepdims=True)
    label_ids=self._bellman_ford_decode(logits)
    id2label=self.model.config.id2label
    offsets=model_outputs["offset_mapping"][0].tolist()
    # Tokens whose character span is empty (start>=end) are dropped;
    # presumably these are special tokens — confirm for other tokenizers.
    entities=[
      {"entity":id2label[j],"start":start,"end":end,"score":probs[i,j]}
      for i,((start,end),j) in enumerate(zip(offsets,label_ids))
      if start<end
    ]
    if kwargs.get("aggregation_strategy","none")!="none":
      entities=self._bellman_ford_aggregate(entities)
    return self._bellman_ford_attach_text(entities,model_outputs["sentence"])

  def _bellman_ford_decode(self,logits):
    """Return the best label-id sequence for ``logits`` under ``self.transition``.

    ``logits`` is a (tokens, labels) array; it is copied, never modified
    (``Tensor.numpy()`` shares memory with the model output).
    A backward pass adds to each row the best achievable score of the
    remaining tokens. Labels are then read off front to back, each one
    restricted to the allowed successors of the previous label. The first
    label is restricted to the successors of label id 0.
    """
    import numpy
    best=logits.copy()
    for i in range(best.shape[0]-1,0,-1):
      # best[i-1][a] += max over labels b allowed after a of best[i][b]
      best[i-1]+=numpy.nanmax(best[i]+self.transition,axis=1)
    path=[numpy.nanargmax(best[0]+self.transition[0])]
    for row in best[1:]:
      path.append(numpy.nanargmax(row+self.transition[path[-1]]))
    return path

  @staticmethod
  def _bellman_ford_aggregate(entities):
    """Merge each ``I-`` entry into the entry before it, in place.

    Scanning from the end, the ``entity`` key is replaced by
    ``entity_group`` (the label without a ``B-``/``I-`` prefix). An ``I-``
    entry is folded into its predecessor, extending its ``end`` and keeping
    the lower score. An ``I-`` entry with no predecessor (index 0) starts
    its own group instead of being merged into the last entry.
    """
    for i in range(len(entities)-1,-1,-1):
      entity=entities[i]
      label=entity.pop("entity")
      if label.startswith("I-") and i>0:
        previous=entities[i-1]
        previous["score"]=min(previous["score"],entity["score"])
        previous["end"]=entities.pop(i)["end"]
      elif label.startswith(("B-","I-")):
        entity["entity_group"]=label[2:]
      else:
        entity["entity_group"]=label
    return entities

  @staticmethod
  def _bellman_ford_attach_text(entities,sentence):
    """Set each entry's ``text`` from ``sentence``, absorbing a trailing tsheg.

    If the character right after an entry is U+0F0B or U+0F0C (Tibetan
    tsheg marks) and the next entry does not start there, ``end`` is
    advanced by one so the mark belongs to this entry.
    """
    for i,entity in enumerate(entities):
      end=entity["end"]
      if end<len(sentence) and sentence[end] in {"\u0f0b","\u0f0c"}:
        if i==len(entities)-1 or end<entities[i+1]["start"]:
          entity["end"]=end+1
      entity["text"]=sentence[entity["start"]:entity["end"]]
    return entities