Spaces:
Sleeping
Sleeping
File size: 2,085 Bytes
8fb0cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import streamlit as st
import pandas as pd
import numpy as np
from unidecode import unidecode
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast
import os
def load_model(
    model_path=os.path.join("models", "lang_detect_hf_distilbert.tflite"),
    label_encoder_path=os.path.join("models", "lang_detect_labelencoder.bin"),
    model_checkpoint="distilbert-base-multilingual-cased",
):
    """Load the TFLite classifier, its label encoder and the matching tokenizer.

    Args:
        model_path: Path to the ``.tflite`` language-detection model.
        label_encoder_path: Path to the cloudpickle-serialised label encoder
            whose ``inverse_transform`` maps predicted class indices back to
            language labels.
        model_checkpoint: Hugging Face checkpoint name passed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    # SECURITY: cloudpickle.load can execute arbitrary code embedded in the
    # file. That is acceptable only because the encoder ships with the app;
    # never point label_encoder_path at user-supplied or downloaded data.
    with open(label_encoder_path, "rb") as encoder_file:
        label_encoder = cloudpickle.load(encoder_file)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    return interpreter, label_encoder, tokenizer


# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model and tokenizer are reloaded each time. st.cache_resource
# (Streamlit >= 1.18) would avoid that, but it shares one object across all
# sessions. The TFLite Interpreter is documented as not thread-safe, so guard
# inference with a lock if you add caching.
interpreter, label_encoder, tokenizer = load_model()
def _resolve_input_indices(input_details):
    """Return ``(attention_mask_index, input_ids_index)`` for the TFLite inputs.

    Matches inputs by tensor name when both names are recognisable. Both
    inputs share the same shape and dtype, so a positional mix-up would
    silently produce wrong predictions instead of an error. When the names
    are not recognisable, falls back to the positional wiring the app has
    always used: input 0 is attention_mask and input 1 is input_ids.
    """
    by_name = {}
    for detail in input_details:
        name = detail.get("name", "")
        if "attention_mask" in name:
            by_name["attention_mask"] = detail["index"]
        elif "input_ids" in name:
            by_name["input_ids"] = detail["index"]
    if len(by_name) == 2:
        return by_name["attention_mask"], by_name["input_ids"]
    return input_details[0]["index"], input_details[1]["index"]


def inference(text):
    """Predict the language of *text* with the TFLite DistilBERT model.

    Relies on the module-level ``interpreter``, ``label_encoder`` and
    ``tokenizer`` produced by ``load_model()``.

    Args:
        text: Raw user input.

    Returns:
        ``"<LABEL> (<score>)"``. LABEL is the upper-cased class name from the
        label encoder. score is the model output for that class, rounded to
        3 decimals (a probability only if the model ends in a softmax, which
        is not visible here).

        Returns ``"Can't Predict"`` when *text* is ``None``, empty or
        whitespace-only.
    """
    # Whitespace-only input carries no language signal; the previous
    # `!= ""` check let it through and produced a meaningless prediction.
    if text is None or not text.strip():
        return "Can't Predict"

    # max_length=50 with padding presumably matches the model's fixed input
    # length. TODO confirm against the converted .tflite input shape.
    tokens = tokenizer(
        text,
        max_length=50,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    # tflite model inference
    interpreter.allocate_tensors()
    attention_index, ids_index = _resolve_input_indices(interpreter.get_input_details())
    output_index = interpreter.get_output_details()[0]["index"]
    interpreter.set_tensor(attention_index, tokens["attention_mask"])
    interpreter.set_tensor(ids_index, tokens["input_ids"])
    interpreter.invoke()

    # First (and only) row of the batch: one score per class.
    scores = interpreter.get_tensor(output_index)[0]
    best = np.argmax(scores)
    label = label_encoder.inverse_transform([best])[0].upper()
    return f"{label} ({str(np.round(scores[best], 3))})"
def main():
    """Render the Streamlit language-detection page.

    The page shows the list of trained languages, a text area and a Submit
    button. On submit, the string returned by ``inference`` is written
    below the form.
    """
    st.title("Language Detection")
    lang_trained = 'eng, rus, ita, tur, epo, ber, deu, kab, fra, por, spa, hun, jpn, heb, ukr, nld, fin, pol, mkd, lit, cmn, mar, ces, dan'.upper()
    # st.write renders strings as Markdown, where a single "\n" is only a soft
    # break, so the list used to run on the same line as the heading.
    # A blank line starts a new paragraph.
    st.write(f'Model is trained on the following languages\n\n{lang_trained}')
    user_text = st.text_area("Enter Text:", "", height=200)
    if st.button("Submit"):
        st.write(inference(user_text))


if __name__ == "__main__":
    main()
|