"""Gradio demo for the pyt_splade SPLADE query and document encoders.

Builds a pyterrier_gradio page whose demos encode queries or documents with
SPLADE using either 'max' or 'sum' aggregation. Each demo returns the encoder
output, a markdown rendering of equivalent Python code, and a plain-text view
of original vs. expansion tokens.
"""
import base64
import re
import json
import pandas as pd
import gradio as gr
import pyterrier as pt
# Kept before the pyt_splade import, which presumably expects an initialised
# PyTerrier -- preserves the original ordering.
pt.init()
import pyt_splade
from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D, df2list

factory_max = pyt_splade.Splade(agg='max')
factory_sum = pyt_splade.Splade(agg='sum')

# Aggregation name (the values offered by the UI dropdowns) -> Splade factory.
_FACTORIES = {'max': factory_max, 'sum': factory_sum}

COLAB_NAME = 'pyterrier_splade.ipynb'
COLAB_INSTALL = '''
!pip install -q git+https://github.com/naver/splade
!pip install -q git+https://github.com/cmacdonald/pyt_splade
'''.strip()


def _toks2span(toks):
    """Join tokens into one space-separated string for display."""
    return ' '.join(f'{t}' for t in toks)


def generate_vis(df, mode='Document'):
    """Render original and expansion tokens for every row of an encoder output.

    Args:
        df: encoder output. In ``'Query'`` mode each row must provide ``qid``,
            ``query`` and a ``query_toks`` token->score dict; otherwise each row
            must provide ``docno``, ``text`` and a ``toks`` token->score dict.
        mode: ``'Query'`` or ``'Document'``; any value other than ``'Query'``
            is treated as ``'Document'``.

    Returns:
        One text section per row, joined by newlines; ``''`` for an empty frame.
    """
    if len(df) == 0:
        return ''
    result = []
    for row in df.itertuples(index=False):
        if mode == 'Query':
            tok_scores = row.query_toks
            text = row.query
            row_id = row.qid
        else:
            tok_scores = row.toks
            text = row.text
            row_id = row.docno
        # The same tokenizer is used for both aggregations, so factory_max's
        # tokenizer is used regardless of which factory produced ``df``.
        orig_tokens = factory_max.tokenizer.tokenize(text)
        orig_tokens_set = set(orig_tokens)
        # Expansion tokens: weighted terms absent from the input text, highest
        # weight first with ties broken alphabetically.
        exp_tokens = [
            tok
            for tok, _ in sorted(tok_scores.items(), key=lambda kv: (-kv[1], kv[0]))
            if tok not in orig_tokens_set
        ]
        result.append(f'''
{mode}: {row_id}
{_toks2span(orig_tokens)}
Expansion Tokens: {_toks2span(exp_tokens)}
''')
    return '\n'.join(result)


def _scores_to_json(tok_scores):
    """Serialise a token->score dict to JSON with scores rounded to 4 places.

    ``float()`` first converts NumPy scalar scores (e.g. ``float32``, which the
    ``json`` module cannot encode) into plain Python floats.
    """
    return json.dumps({k: round(float(v), 4) for k, v in tok_scores.items()})


def _encode(input, agg, kind, toks_column, mode):
    """Run the ``kind`` encoder ('query' or 'doc') for ``agg`` and package the output.

    Args:
        input: data passed to the encoder pipeline and to ``df2list``.
        agg: aggregation name; must be a key of ``_FACTORIES``.
        kind: ``'query'`` or ``'doc'``; selects ``<kind>_encoder()``.
        toks_column: output column holding token->score dicts.
        mode: visualisation mode forwarded to :func:`generate_vis`.

    Returns:
        ``(results, code_markdown, visualisation)``. In ``results`` the
        ``toks_column`` dicts are replaced by rounded JSON strings.

    Raises:
        KeyError: if ``agg`` is not a known aggregation.
    """
    # Reproducible snippet shown alongside the output; mirrors the call below.
    code = (
        'import pyt_splade\n'
        f'splade = pyt_splade.Splade(agg={agg!r})\n'
        f'{kind}_pipeline = splade.{kind}_encoder()\n'
        f'{kind}_pipeline({df2list(input)})\n'
    )
    pipeline = getattr(_FACTORIES[agg], f'{kind}_encoder')()
    res = pipeline(input)
    # Visualise from the raw dicts before they are serialised for display.
    vis = generate_vis(res, mode=mode)
    res[toks_column] = [_scores_to_json(t) for t in res[toks_column]]
    return (res, code2md(code, COLAB_INSTALL, COLAB_NAME), vis)


def predict_query(input, agg):
    """Encode queries with SPLADE for the query demo.

    Args:
        input: queries to encode (presumably a DataFrame with ``qid``/``query``).
        agg: ``'max'`` or ``'sum'``.

    Returns:
        ``(results, code_markdown, visualisation)`` with ``query_toks`` as JSON.
    """
    return _encode(input, agg, 'query', 'query_toks', 'Query')


def predict_doc(input, agg):
    """Encode documents with SPLADE for the document demo.

    Args:
        input: documents to encode (presumably a DataFrame with ``docno``/``text``).
        agg: ``'max'`` or ``'sum'``.

    Returns:
        ``(results, code_markdown, visualisation)`` with ``toks`` as JSON.
    """
    return _encode(input, agg, 'doc', 'toks', 'Document')


interface(
    MarkdownFile('README.md'),
    MarkdownFile('query.md'),
    Demo(
        predict_query,
        EX_Q,
        [
            gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
        ],
        scale=2/3,
    ),
    MarkdownFile('doc.md'),
    Demo(
        predict_doc,
        EX_D,
        [
            gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation'),
        ],
        scale=2/3,
    ),
    MarkdownFile('wrapup.md'),
).launch(share=False)