|
import streamlit as st |
|
import torch |
|
from peft import PeftModel |
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
model_name = 'intfloat/multilingual-e5-large'
adapters_name = './checkpoint-21170'


@st.cache_resource
def load_model_and_tokenizer(base_name, adapter_path):
    """Load the base encoder, merge the PEFT adapter into it, and load the tokenizer.

    Streamlit re-runs this whole script on every widget interaction. Caching the
    result with ``st.cache_resource`` means the expensive load-and-merge happens
    once per server process instead of on every rerun.

    Args:
        base_name: Hugging Face model id (or local path) of the base model;
            also used to load the tokenizer.
        adapter_path: Path to the PEFT adapter checkpoint to merge into the base.

    Returns:
        A ``(model, tokenizer)`` pair: the adapter-merged model in eval mode and
        the base model's tokenizer.
    """
    base_model = AutoModel.from_pretrained(base_name)
    # Fold the adapter weights into the base weights so inference runs on a plain model.
    merged_model = PeftModel.from_pretrained(base_model, adapter_path).merge_and_unload()
    # Inference only: make sure dropout and similar layers are in eval behaviour.
    merged_model.eval()
    tok = AutoTokenizer.from_pretrained(base_name)
    return merged_model, tok


model, tokenizer = load_model_and_tokenizer(model_name, adapters_name)
|
|
|
def _average_pool(last_hidden, attention_mask):
    """Mean of token embeddings per sequence, ignoring padding positions.

    Args:
        last_hidden: (batch, seq_len, hidden) token embeddings.
        attention_mask: (batch, seq_len) mask, 1 for real tokens, 0 for padding.

    Returns:
        (batch, hidden) tensor of averaged embeddings.
    """
    masked = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return masked.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


description = st.text_input("Product description")
review = st.text_input("Review")

if description and review:
    # E5-style prefixes: the review is encoded as the query, the description as the passage.
    input_texts = [
        f'query: {review}',
        f'passage: {description}',
    ]

    batch_dict = tokenizer(input_texts, max_length=512,
                           padding=True, truncation=True, return_tensors='pt')

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        outputs = model(**batch_dict, return_dict=True)

    # E5 models are documented to use attention-mask average pooling of the last
    # hidden state rather than pooler_output.
    # NOTE(review): pooling must match how the adapter was trained — confirm.
    embeddings = _average_pool(outputs.last_hidden_state,
                               batch_dict['attention_mask'])  # (2, hidden)

    # Slice (not index) so each side keeps a batch dim of 1 and dim=1 is the
    # feature axis; with 1-D inputs the default dim=1 is out of range.
    similarity = torch.nn.functional.cosine_similarity(
        embeddings[0:1], embeddings[1:2]).item()

    # Decision cut-off on cosine similarity.
    # NOTE(review): value is not calibrated in this file — TODO tune on labeled pairs.
    threshold = 0.7

    if similarity > threshold:
        st.write('Relevant')
    else:
        st.write('Irrelevant')
|
|