import base64
import re
import json
import pandas as pd
import gradio as gr
import pyterrier as pt
pt.init()
import pyt_splade
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D, df2list
# One SPLADE encoder factory per aggregation strategy selectable in the UI;
# predict_query / predict_doc pick between them by the 'max' / 'sum' key.
factory_max = pyt_splade.Splade(agg='max')
factory_sum = pyt_splade.Splade(agg='sum')

# Notebook name and install commands handed to code2md together with each
# generated code snippet.
COLAB_NAME = 'pyterrier_splade.ipynb'
COLAB_INSTALL = '\n'.join([
    '!pip install -q git+https://github.com/naver/splade',
    '!pip install -q git+https://github.com/cmacdonald/pyt_splade',
])
def generate_vis(df, mode='Document', *, tokenizer=None):
    """Render a plain-text summary of SPLADE token weights for each row of ``df``.

    For every row the original text is tokenized and printed, followed by the
    "expansion tokens": tokens that carry a weight in the encoder output but do
    not occur among the original tokens, ordered by descending weight with ties
    broken alphabetically.

    Args:
        df: Encoder output. In ``'Query'`` mode the ``qid``, ``query`` and
            ``query_toks`` columns are read; in any other mode ``docno``,
            ``text`` and ``toks``. The ``*toks`` cells are expected to be
            ``{token: weight}`` dicts.
        mode: ``'Query'`` or ``'Document'``. Any value other than ``'Query'``
            uses the document columns. It is also the label printed per row.
        tokenizer: Object with a ``tokenize(str) -> list[str]`` method used to
            split the original text. Defaults to ``factory_max.tokenizer``.

    Returns:
        One text block per row, joined with newlines; ``''`` for an empty ``df``.
    """
    if len(df) == 0:
        return ''
    if tokenizer is None:
        # Resolved lazily so callers can supply their own tokenizer.
        tokenizer = factory_max.tokenizer

    def toks2span(toks):
        return ' '.join(f'{t}' for t in toks)

    if mode == 'Query':
        id_col, text_col, weights_col = 'qid', 'query', 'query_toks'
    else:
        id_col, text_col, weights_col = 'docno', 'text', 'toks'

    # NOTE: the former per-mode ``max_score`` computation was never used and
    # raised ValueError whenever a token-weight dict was empty; it is removed.
    result = []
    for row_id, text, tok_scores in zip(df[id_col], df[text_col], df[weights_col]):
        orig_tokens = tokenizer.tokenize(text)
        orig_tokens_set = set(orig_tokens)
        exp_tokens = [
            t
            for t, _ in sorted(tok_scores.items(), key=lambda x: (-x[1], x[0]))
            if t not in orig_tokens_set
        ]
        result.append(
            f'\n{mode}: {row_id}\n'
            f'{toks2span(orig_tokens)}\n'
            f'Expansion Tokens: {toks2span(exp_tokens)}\n'
        )
    return '\n'.join(result)
def predict_query(input, agg):
    """Encode queries with SPLADE for the query demo.

    Args:
        input: Queries passed straight to the SPLADE query encoder (presumably a
            DataFrame with ``qid`` / ``query`` columns, since the visualization
            reads those from the encoder output — confirm against pyterrier_gradio).
        agg: Aggregation key, ``'max'`` or ``'sum'``; any other value raises KeyError.

    Returns:
        Tuple of (encoder output with ``query_toks`` serialized as JSON strings
        rounded to 4 decimals, the result of ``code2md`` for an equivalent
        standalone snippet, the text from ``generate_vis``).
    """
    # Build the reproducible snippet first, from the untouched input.
    snippet = (
        'import pyt_splade\n'
        f'splade = pyt_splade.Splade(agg={agg!r})\n'
        'query_pipeline = splade.query_encoder()\n'
        f'query_pipeline({df2list(input)})\n'
    )

    factories = {'max': factory_max, 'sum': factory_sum}
    encoder = factories[agg].query_encoder()
    encoded = encoder(input)

    # Visualize while the token weights are still dicts; they are stringified below.
    rendered = generate_vis(encoded, mode='Query')
    encoded['query_toks'] = [
        json.dumps({tok: round(weight, 4) for tok, weight in weights.items()})
        for weights in encoded['query_toks']
    ]

    return encoded, code2md(snippet, COLAB_INSTALL, COLAB_NAME), rendered
def predict_doc(input, agg):
    """Encode documents with SPLADE for the document demo.

    Args:
        input: Documents passed straight to the SPLADE document encoder
            (presumably a DataFrame with ``docno`` / ``text`` columns, since the
            visualization reads those from the encoder output — confirm
            against pyterrier_gradio).
        agg: Aggregation key, ``'max'`` or ``'sum'``; any other value raises KeyError.

    Returns:
        Tuple of (encoder output with ``toks`` serialized as JSON strings
        rounded to 4 decimals, the result of ``code2md`` for an equivalent
        standalone snippet, the text from ``generate_vis``).
    """
    # Build the reproducible snippet first, from the untouched input.
    snippet = (
        'import pyt_splade\n'
        f'splade = pyt_splade.Splade(agg={agg!r})\n'
        'doc_pipeline = splade.doc_encoder()\n'
        f'doc_pipeline({df2list(input)})\n'
    )

    factories = {'max': factory_max, 'sum': factory_sum}
    encoder = factories[agg].doc_encoder()
    encoded = encoder(input)

    # Visualize while the token weights are still dicts; they are stringified below.
    rendered = generate_vis(encoded, mode='Document')
    encoded['toks'] = [
        json.dumps({tok: round(weight, 4) for tok, weight in weights.items()})
        for weights in encoded['toks']
    ]

    return encoded, code2md(snippet, COLAB_INSTALL, COLAB_NAME), rendered
def _aggregation_dropdown():
    """Build a new aggregation selector for one Demo section."""
    return gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation')


# Page layout: markdown sections interleaved with the query and document demos;
# launched with share=False.
app = interface(
    MarkdownFile('README.md'),
    MarkdownFile('query.md'),
    Demo(predict_query, EX_Q, [_aggregation_dropdown()], scale=2/3),
    MarkdownFile('doc.md'),
    Demo(predict_doc, EX_D, [_aggregation_dropdown()], scale=2/3),
    MarkdownFile('wrapup.md'),
)
app.launch(share=False)