# Hugging Face Space "Keyword Score" (app.py, by nitinbhayana).
# Scores how semantically close an input text is to a reference search term.
import gradio as gr
import torch
from transformers import pipeline
# Load the UAE-Large-V1 feature-extraction model once at import time.
# NOTE(review): this rebinds (shadows) the imported ``pipeline`` factory,
# so no other pipeline can be constructed under that name later in the file.
pipeline = pipeline("feature-extraction", model="WhereIsAI/UAE-Large-V1")
def predict(text, term="multivitamin for men"):
    """Return a 0-100 semantic-similarity score between *text* and *term*.

    Both strings are embedded with the module-level feature-extraction
    ``pipeline``, mean-pooled over the token axis, and compared with
    cosine similarity; the result is scaled to a percentage.

    Args:
        text: input string (e.g. a product title) to score.
        term: reference search term. Defaults to the previously
            hard-coded query, so existing callers see identical behavior.

    Returns:
        The score formatted as a string with two decimal places.
    """
    # pipeline(...) returns nested Python lists — presumably shaped
    # (1, num_tokens, hidden_dim); convert to a tensor for pooling.
    title_outputs = torch.tensor(pipeline(text))
    # Mean-pool over the token axis to get one embedding per input.
    title_embedding = title_outputs.mean(dim=1)

    term_outputs = torch.tensor(pipeline(term))
    term_embedding = term_outputs.mean(dim=1)

    # The original called an undefined ``cosine_similarity``; torch's
    # functional version with dim=0 handles the flattened 1-D vectors.
    # .item() extracts the Python float so format(..., '.2f') works.
    semantic_score = torch.nn.functional.cosine_similarity(
        title_embedding.flatten(), term_embedding.flatten(), dim=0
    ).item() * 100
    return format(semantic_score, ".2f")
# Wire the scoring function into a minimal Gradio UI:
# one free-text input, one text output showing the score.
gradio_app = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
    title="Keyword Score",
)

if __name__ == "__main__":
    # Start the local web server only when executed as a script.
    gradio_app.launch()