# Author: MorenoLaQuatra — commit 9ccc55b ("Map location device"), 2.9 kB.
import html
from functools import partial

import gradio as gr
import torch
import transformers
from transformers import pipeline

from dual_regression_model import DualRegressionModel
# ---------------------------------------------------------------------------
# Model loading — executed once, at import time.
# ---------------------------------------------------------------------------

# Classifier: fine-tuned checkpoint stored in the local "clf_model/" folder
# (trained from A-pt-bs16-dbmdz-bert-base-italian-cased). Tokenizer and model
# are loaded from the same folder and wrapped in a text-classification pipeline.
clf_model_tag = "clf_model/"
clf_tokenizer, clf_model = (
    transformers.AutoTokenizer.from_pretrained(clf_model_tag),
    transformers.AutoModelForSequenceClassification.from_pretrained(clf_model_tag),
)
clf_pipeline = pipeline(
    "text-classification",
    model=clf_model,
    tokenizer=clf_tokenizer,
)

# Regressor: DualRegressionModel built on a multilingual DistilBERT backbone,
# whose weights are then restored from a local checkpoint file.
# NOTE(review): if DualRegressionModel is a torch nn.Module, confirm that
# load_model() switches it to eval mode; otherwise dropout (if any) would make
# predictions nondeterministic.
reg_model_tag, reg_model_folder = (
    "distilbert-base-multilingual-cased",
    "reg_model/regression_model.pt",
)
reg_model = DualRegressionModel(model_name_or_path=reg_model_tag)
reg_model.load_model(reg_model_folder)
# Bounding box (in degrees) of the embedded OpenStreetMap view. It frames Italy
# and stays fixed regardless of where the predicted point falls.
LAT_MIN, LAT_MAX = 38, 46
LONG_MIN, LONG_MAX = 8, 18


def predict(text):
    """Locate the dialect of ``text`` and render the result as HTML.

    Two models run on the same input:

    * ``clf_pipeline`` predicts a region label;
    * ``reg_model`` regresses a (latitude, longitude) point.

    Args:
        text: Input sentence coming from the Gradio textbox.

    Returns:
        An HTML fragment containing the predicted region, the predicted
        coordinates, an embedded OpenStreetMap iframe with a marker on the
        predicted point, and a link to a larger map centred on it.
    """
    # Classification: the pipeline returns one dict per input; take the first.
    # truncation=True clips inputs longer than the model's maximum sequence
    # length instead of letting them raise.
    clf_prediction = clf_pipeline(text, truncation=True)[0]

    # Regression: inference only, so disable autograd bookkeeping.
    # Assumes reg_model.tokenizer is a Hugging Face tokenizer (it is called
    # with return_tensors="pt") — TODO confirm.
    reg_input = reg_model.tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        reg_prediction = reg_model(reg_input)
    latitude = reg_prediction["latitude"].item()
    longitude = reg_prediction["longitude"].item()

    # Escape the label before interpolating it into markup.
    label = html.escape(str(clf_prediction["label"]))

    # Inside an HTML attribute every '&' separating query parameters is
    # written as '&amp;'.
    map_src = (
        "https://www.openstreetmap.org/export/embed.html"
        f"?bbox={LONG_MIN}%2C{LAT_MIN}%2C{LONG_MAX}%2C{LAT_MAX}"
        f"&amp;layer=mapnik&amp;marker={latitude}%2C{longitude}"
    )

    parts = [
        f"<h3>The identified region is: {label}</h3>",
        # Plot the predicted point on the map of Italy.
        "<h3>Predicted point on map:</h3>",
        f"<p>Latitude: {latitude}</p><p>Longitude: {longitude}</p>",
        '<iframe width="425" height="350" frameborder="0" scrolling="no" '
        'marginheight="0" marginwidth="0" '
        f'src="{map_src}" style="border: 1px solid black"></iframe>',
        "<br/>",
        '<small><a href="https://www.openstreetmap.org/#map=13/'
        f'{latitude}/{longitude}">Visualizza mappa ingrandita</a></small>',
    ]
    return "".join(parts)
# --------------------------------------------------------------------------------------------
# Gradio interface
# --------------------------------------------------------------------------------------------
# One free-text input and one HTML output produced by ``predict``. The examples
# are dialectal sentences the user can click to try the demo.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Insert the text here..."),
    outputs=gr.HTML(),
    title="DANTE: Dialect ANalysis TEam",
    description="This is a demo of a classification and regression model for locating the Italian dialect of a given text.",
    examples=[
        ["Bisognerebbe saperli materializzare .... !! Ma ovviamente .. belin .... NO SE PEU SCIUSCIA' E SCIORBI'"],
        ["Guaglio' Buongiorno! Azz! Vir te si scurdat puparuol e mulignane pero '!! E che se fa😑"],
        ["Il massimo...ghe ne minga par nisun"],
        ["Che poi a me la tuta piace na cifra da vede. Subisco un po' lo stigma sociale che noi con la fregna dovemo stà sempre apposto."],
    ],
)

# Launch only when run as a script (e.g. `python app.py`), so merely importing
# this module does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()