import torch
from transformers import (AutoModelForSequenceClassification, AutoModelForSeq2SeqLM,
AutoConfig, AutoModelForTokenClassification,
AutoTokenizer, pipeline)
from peft import PeftModel, PeftConfig
import streamlit as st


@st.cache_resource
def load_sentiment_analyzer():
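    """Load the tokenizer and sequence-classification model for WhatsApp sentiment analysis."""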
tokenizer = AutoTokenizer.from_pretrained("aliciiavs/sentiment-analysis-whatsapp2")
model = AutoModelForSequenceClassification.from_pretrained("aliciiavs/sentiment-analysis-whatsapp2")
    return tokenizer, model


@st.cache_resource
def load_summarizer():
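    """Load bart-large, attach the SAMSum PEFT adapter, and merge it into the base model for summarization."""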
config = PeftConfig.from_pretrained("marcelomoreno26/bart-large-samsum-adapter")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
tokenizer.pad_token = tokenizer.eos_token
model = PeftModel.from_pretrained(model, "marcelomoreno26/bart-large-samsum-adapter", config=config)
model = model.merge_and_unload()
    return tokenizer, model


@st.cache_resource
def load_NER():
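    """Build a token-classification pipeline for multilingual NER with averaged subword aggregation."""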
config = AutoConfig.from_pretrained("hannahisrael03/wikineural-multilingual-ner-finetuned-wikiann")
    model = AutoModelForTokenClassification.from_pretrained("hannahisrael03/wikineural-multilingual-ner-finetuned-wikiann", config=config)
tokenizer = AutoTokenizer.from_pretrained("hannahisrael03/wikineural-multilingual-ner-finetuned-wikiann")
pipe = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
    return pipe


def get_sentiment_analysis(text, tokenizer, model):
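    """Classify `text` and return its predicted label: sadness, joy, love, anger, fear, or surprise."""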
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# Get predicted probabilities and predicted label
probabilities = torch.softmax(outputs.logits, dim=1)
predicted_label = torch.argmax(probabilities, dim=1)
# Convert the predicted label tensor to a Python integer
predicted_label = predicted_label.item()
# Map predicted label index to sentiment label
label_dic = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
    # Look up and return the predicted sentiment label
    return label_dic[predicted_label]


def generate_summary(text, tokenizer, model):
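    """Summarize `text`; inputs over 512 tokens are split into overlapping segments whose summaries are joined."""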
prefix = "summarize: "
    encoded_input = tokenizer(prefix + text, return_tensors='pt', add_special_tokens=True)
input_ids = encoded_input['input_ids']
# Check if input_ids exceed the model's max length
max_length = 512
if input_ids.shape[1] > max_length:
# Split the input_ids into manageable segments
total_summary = []
        for i in range(0, input_ids.shape[1], max_length - 50):  # stride of max_length - 50 gives a 50-token overlap so context carries across segments
segment_ids = input_ids[:, i:i + max_length]
output_ids = model.generate(segment_ids, max_length=150, num_beams=5, early_stopping=True)
segment_summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
total_summary.append(segment_summary)
# Concatenate all segment summaries
summary = ' '.join(total_summary)
else:
# Process as usual
output_ids = model.generate(input_ids, max_length=150, num_beams=5, early_stopping=True)
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return summary


def get_NER(text, pipe):
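    """Run NER on `text`, keep the highest-scoring hit per (entity type, word) pair, and return [word, entity_group] rows; ORG names longer than two words are dropped."""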
# Use pipeline to predict NER
results = pipe(text)
# Filter duplicates while retaining the highest score for each entity type and word combination
unique_entities = {}
for ent in results:
key = (ent['entity_group'], ent['word'])
if key not in unique_entities or unique_entities[key]['score'] < ent['score']:
unique_entities[key] = ent
    # Sort by start position so entities keep the order in which they appear in the text
    sorted_entities = sorted(unique_entities.values(), key=lambda x: x['start'])
    # Format the results as [word, entity_group] rows for a table display
    formatted_results = [[ent['word'], ent['entity_group']] for ent in sorted_entities]
    filtered_results = []
    for entity in formatted_results:
        if entity[1] == 'ORG':
            # Keep ORG entities only if the name is at most two words long
            if len(entity[0].split()) <= 2:
                filtered_results.append(entity)
        else:
            # Keep non-ORG entities without filtering
            filtered_results.append(entity)
    return filtered_results
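

# Minimal usage sketch (an assumption, not part of the Streamlit app): running the
# module directly downloads the checkpoints and exercises the three pipelines once.
# The @st.cache_resource loaders also execute outside `streamlit run`; Streamlit
# simply falls back to running them uncached. The sample conversation is invented.
if __name__ == "__main__":
    sample = "Anna: I got the job at Acme! Ben: Congratulations, that's amazing news!"

    tokenizer, model = load_sentiment_analyzer()
    print("Sentiment:", get_sentiment_analysis(sample, tokenizer, model))

    sum_tokenizer, sum_model = load_summarizer()
    print("Summary:", generate_summary(sample, sum_tokenizer, sum_model))

    ner_pipe = load_NER()
    print("Entities:", get_NER(sample, ner_pipe))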