output_fn
Browse files- code/inference.py +8 -0
code/inference.py
CHANGED
@@ -12,3 +12,11 @@ def predict_fn(data: Union[List[str], str], model):
|
|
12 |
outputs = model(data, padding=False, truncation=True)
|
13 |
embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
|
14 |
return embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
outputs = model(data, padding=False, truncation=True)
|
13 |
embeddings = [np.array(r[0]).mean(axis=0).tolist() for r in outputs]
|
14 |
return embeddings
|
15 |
+
|
16 |
+
|
17 |
+
def output_fn(prediction, accept):
|
18 |
+
return json.dumps(
|
19 |
+
obj={
|
20 |
+
"outputs": prediction
|
21 |
+
}
|
22 |
+
)
|