amiguel commited on
Commit
3aafe68
1 Parent(s): 2e5be48

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -2
app.py CHANGED
@@ -1,3 +1,59 @@
1
- import gradio as gr
 
 
2
 
3
- gr.load("models/amiguel/fintune_naming_model").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ import torch
4
 
5
+ model_name = "amiguel/fintune_naming_model" # Replace with your model repo
6
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model.to(device)
10
+
11
+ def classify_review(text, model, tokenizer, device, max_length=512):
12
+ model.eval()
13
+ inputs = tokenizer.encode_plus(
14
+ text,
15
+ truncation=True,
16
+ padding='max_length',
17
+ max_length=max_length,
18
+ return_tensors="pt"
19
+ )
20
+ input_ids = inputs['input_ids'].to(device)
21
+ attention_mask = inputs['attention_mask'].to(device)
22
+
23
+ with torch.no_grad():
24
+ outputs = model(input_ids, attention_mask=attention_mask)
25
+ logits = outputs.logits
26
+ predicted_label = torch.argmax(logits, dim=-1).item()
27
+ return "Proper Naming otfcn" if predicted_label == 1 else "Wrong Naming notfcn"
28
+
29
+ def main():
30
+ st.title("Notifications Naming Classifier")
31
+
32
+ input_option = st.radio("Select input option", ("Single Text Query", "Upload Table"))
33
+
34
+ if input_option == "Single Text Query":
35
+ text_query = st.text_input("Enter text query")
36
+ if st.button("Classify"):
37
+ if text_query:
38
+ predicted_label = classify_review(text_query, model, tokenizer, device)
39
+ st.write("Predicted Label:")
40
+ st.write(predicted_label)
41
+ else:
42
+ st.warning("Please enter a text query.")
43
+
44
+ elif input_option == "Upload Table":
45
+ uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
46
+ if uploaded_file is not None:
47
+ import pandas as pd
48
+ if uploaded_file.name.endswith(".csv"):
49
+ df = pd.read_csv(uploaded_file)
50
+ else:
51
+ df = pd.read_excel(uploaded_file)
52
+
53
+ text_column = st.selectbox("Select the text column", df.columns)
54
+ predicted_labels = [classify_review(text, model, tokenizer, device) for text in df[text_column]]
55
+ df["Predicted Label"] = predicted_labels
56
+ st.write(df)
57
+
58
+ if __name__ == "__main__":
59
+ main()