Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel, MT5EncoderModel | |
class MTRankerConfig(PretrainedConfig):
    """Configuration object for the MTRanker model.

    Holds the identifier of the encoder backbone in ``backbone``; every
    other keyword argument is forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(self, backbone: str = 'google/mt5-base', **kwargs) -> None:
        # The attribute is assigned before delegating to the base-class
        # initializer; this ordering is kept exactly as it was.
        self.backbone = backbone
        super().__init__(**kwargs)
class MTRanker(PreTrainedModel):
    """Encoder-plus-linear-head classifier with two output classes.

    The encoder is an ``MT5EncoderModel`` loaded from ``config.backbone``.
    Its token representations are mean-pooled over the positions selected
    by the attention mask and passed through a linear layer that produces
    ``num_classes`` (2) logits per input row.
    """

    config_class = MTRankerConfig

    def __init__(self, config):
        super().__init__(config)
        # The backbone checkpoint named in the config is loaded here, so
        # constructing this model requires that checkpoint to be available.
        self.encoder = MT5EncoderModel.from_pretrained(config.backbone)
        self.num_classes = 2
        self.classifier = torch.nn.Linear(self.encoder.config.hidden_size, self.num_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids, rank-2 tensor (batch, sequence).
            attention_mask: Rank-2 mask (batch, sequence); non-zero entries
                mark the positions that take part in the pooling.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized logits.
        """
        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Broadcasting the (batch, sequence, 1) mask over the hidden
        # dimension zeroes out masked positions before summing.
        mask = attention_mask.unsqueeze(-1)
        summed = torch.sum(hidden_states * mask, dim=1)
        # Clamp the per-row token count to at least 1 so that a row whose
        # mask is entirely zero yields a zero vector instead of 0/0 = NaN.
        token_counts = torch.sum(attention_mask, dim=1, keepdim=True).clamp(min=1)
        pooled_hidden_state = summed / token_counts
        return self.classifier(pooled_hidden_state)
# Module-level setup: executed once when the file is loaded.

# The config is only used here to name the backbone whose tokenizer is loaded.
config = MTRankerConfig(
    backbone='google/mt5-base',
)
tokenizer = AutoTokenizer.from_pretrained(
    config.backbone,
)

# Ranker checkpoint that `predict` runs inference with.
model = MTRanker.from_pretrained(
    'ibraheemmoosa/mt-ranker-base',
)
def predict(source, translation1, translation2):
    """Score two candidate translations of a source sentence.

    The three strings are formatted into a single prompt, tokenized
    (truncated/padded to 512 tokens) and passed through the module-level
    ``model``. The two logits are turned into probabilities with a softmax.

    Args:
        source: The source sentence.
        translation1: First candidate translation.
        translation2: Second candidate translation.

    Returns:
        A dict mapping ``'Translation 1'`` and ``'Translation 2'`` to their
        softmax probabilities as plain Python floats.
    """
    # The prompt numbers the candidates 0 and 1, while the returned keys
    # use 1 and 2 to match the labels of the input textboxes.
    model_input = "Source: {} Translation 0: {} Translation 1: {}".format(source, translation1, translation2)
    inputs = tokenizer([model_input], max_length=512, padding='max_length', truncation=True, return_tensors='pt')
    with torch.inference_mode():
        logits = model(inputs.input_ids, inputs.attention_mask)
    # Only one example is in the batch, so take row 0. `.tolist()` converts
    # the 0-dim tensors to Python floats, the value type expected for
    # label confidences and safe to serialize.
    score_first, score_second = torch.softmax(logits, dim=1)[0].tolist()
    return {'Translation 1': score_first, 'Translation 2': score_second}
# Input widgets, pre-filled with a small French-to-English example in which
# the second candidate is the correct translation.
source_textbox = gr.Textbox(label="Source", info="Source Sentence", value="Le chat est sur le tapis.")
translation1_textbox = gr.Textbox(label="Translation 1", info="Translation 1", value="The cat is on the bed.")
translation2_textbox = gr.Textbox(label="Translation 2", info="Translation 2", value="The cat is on the carpet.")

# Label component that displays the per-translation scores returned by `predict`.
output = gr.Label(label="Result")

iface = gr.Interface(
    fn=predict,
    inputs=[source_textbox, translation1_textbox, translation2_textbox],
    outputs=output,
)

# Start the web server only when the file is run as a script, so importing
# this module does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch()