macavaney commited on
Commit
f1889e7
1 Parent(s): bbda641

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -48,14 +48,13 @@ def generate_vis(df, mode='Document'):
48
  return '\n'.join(result)
49
 
50
  def predict_query(input, agg):
51
- code = f'''import pandas as pd
52
- import pyt_splade
53
 
54
  splade = pyt_splade.Splade(agg={agg!r})
55
 
56
  query_pipeline = splade.query_encoder()
57
 
58
- query_pipeline({df2code(input)})
59
  '''
60
  pipeline = {
61
  'max': factory_max,
@@ -63,17 +62,17 @@ query_pipeline({df2code(input)})
63
  }[agg].query_encoder()
64
  res = pipeline(input)
65
  vis = generate_vis(res, mode='Query')
 
66
  return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)
67
 
68
  def predict_doc(input, agg):
69
- code = f'''import pandas as pd
70
- import pyt_splade
71
 
72
  splade = pyt_splade.Splade(agg={repr(agg)})
73
 
74
  doc_pipeline = splade.doc_encoder()
75
 
76
- doc_pipeline({df2code(input)})
77
  '''
78
  pipeline = {
79
  'max': factory_max,
 
48
  return '\n'.join(result)
49
 
50
  def predict_query(input, agg):
51
+ code = f'''import pyt_splade
 
52
 
53
  splade = pyt_splade.Splade(agg={agg!r})
54
 
55
  query_pipeline = splade.query_encoder()
56
 
57
+ query_pipeline({df2list(input)})
58
  '''
59
  pipeline = {
60
  'max': factory_max,
 
62
  }[agg].query_encoder()
63
  res = pipeline(input)
64
  vis = generate_vis(res, mode='Query')
65
+ res['query_toks'] = [json.dumps({k: round(v, 4) for k, v in t.items()}) for t in res['query_toks']]
66
  return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)
67
 
68
  def predict_doc(input, agg):
69
+ code = f'''import pyt_splade
 
70
 
71
  splade = pyt_splade.Splade(agg={repr(agg)})
72
 
73
  doc_pipeline = splade.doc_encoder()
74
 
75
+ doc_pipeline({df2list(input)})
76
  '''
77
  pipeline = {
78
  'max': factory_max,