File size: 2,085 Bytes
8fb0cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import streamlit as st
import pandas as pd
import numpy as np
from unidecode import unidecode
import tensorflow as tf 
import cloudpickle
from transformers import DistilBertTokenizerFast
import os

def load_model(
    model_path=os.path.join("models", "lang_detect_hf_distilbert.tflite"),
    label_encoder_path=os.path.join("models", "lang_detect_labelencoder.bin"),
    model_checkpoint="distilbert-base-multilingual-cased",
):
    """Load the TFLite classifier, its label encoder and the tokenizer.

    Relative paths are resolved against the current working directory.

    Args:
        model_path: Path to the ``.tflite`` model file.
        label_encoder_path: Path to the cloudpickle-serialized label encoder
            that maps class indices back to language codes.
        model_checkpoint: Hugging Face checkpoint name for the tokenizer.
            Presumably this must be the checkpoint the model was fine-tuned
            from — confirm before changing.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)

    # SECURITY NOTE: unpickling can execute arbitrary code. This is only
    # acceptable for a trusted file that presumably ships with the app;
    # never point this at user-supplied data.
    with open(label_encoder_path, "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    return interpreter, label_encoder, tokenizer


# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model, encoder and tokenizer are reloaded on each click. Wrapping
# load_model in st.cache_resource would avoid this. However, the single TFLite
# interpreter would then be shared across sessions, and it is reportedly not
# thread-safe. Add locking around inference before caching it.
interpreter, label_encoder, tokenizer = load_model()

def inference(text, max_length=50):
    """Predict the language of *text* with the module-level TFLite model.

    Uses the module-level ``interpreter``, ``label_encoder`` and
    ``tokenizer`` returned by ``load_model()``.

    Args:
        text: Input string to classify.
        max_length: Sequence length the tokens are padded/truncated to.
            Presumably this must match the input shape the TFLite model was
            exported with (50 by default) — confirm before changing.

    Returns:
        For the top-scoring class, ``"<LABEL> (<score>)"``, where the score is
        the model's raw output value for that class rounded to 3 decimals.
        Returns ``"Can't Predict"`` when *text* is empty, whitespace-only or
        ``None``.
    """
    # Whitespace-only text would tokenize to nothing but special tokens and
    # still produce a meaningless prediction, so reject it up front.
    if not text or not text.strip():
        return "Can't Predict"

    tokens = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]

    # Bind tokenizer outputs to interpreter inputs by tensor name, because
    # relying on position alone is fragile. If the names don't identify both
    # inputs, keep the original positional order (attention_mask first,
    # input_ids second).
    input_keys = ("attention_mask", "input_ids")
    matched = [
        next((d for d in input_details if key in d["name"]), None)
        for key in input_keys
    ]
    if any(d is None for d in matched):
        matched = input_details[: len(input_keys)]

    for key, detail in zip(input_keys, matched):
        # Cast to the dtype the interpreter declares; set_tensor rejects
        # mismatched integer widths (e.g. INT32 vs INT64).
        interpreter.set_tensor(
            detail["index"], np.asarray(tokens[key], dtype=detail["dtype"])
        )
    interpreter.invoke()

    scores = interpreter.get_tensor(output_details["index"])[0]
    best = np.argmax(scores)
    label = label_encoder.inverse_transform([best])[0].upper()
    # Keep the explicit str() so the text matches the original output.
    # Formatting a NumPy scalar directly in an f-string may use a different
    # float representation.
    return f"{label} ({str(np.round(scores[best], 3))})"


def main():
    """Render the single-page language-detection UI.

    Shows the supported language codes, takes free text from a text area and,
    once "Submit" is pressed, displays the result of ``inference``.
    """
    st.title("Language Detection")

    trained_codes = (
        "eng", "rus", "ita", "tur", "epo", "ber", "deu", "kab",
        "fra", "por", "spa", "hun", "jpn", "heb", "ukr", "nld",
        "fin", "pol", "mkd", "lit", "cmn", "mar", "ces", "dan",
    )
    languages_upper = ", ".join(trained_codes).upper()
    # The two spaces before "\n" make st.write's Markdown render a line break.
    st.write(f"Model is trained on the following languages  \n{languages_upper}")

    user_text = st.text_area("Enter Text:", "", height=200)
    if not st.button("Submit"):
        return

    prediction = inference(user_text)
    st.write(prediction)


if __name__ == "__main__":
    main()