ifmain commited on
Commit
bcaa150
1 Parent(s): 36fee4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -1,21 +1,20 @@
1
  import gradio as gr
2
  import torch
3
- from moderation import * # Убедитесь, что в moderation.py есть функции getEmb и predict
 
4
 
5
- # Загрузка модели
6
  moderation = ModerationModel()
7
- moderation.load_state_dict(torch.load('moderation_model.pth', map_location=torch.device('cpu')))
8
- moderation.eval() # Переключение модели в режим оценки
9
 
10
  def predict_moderation(text):
11
  embeddings_for_prediction = getEmb(text)
12
  prediction = predict(moderation, embeddings_for_prediction)
13
- # Предполагая, что prediction возвращает словарь с оценками и флагом обнаружения
14
- category_scores = prediction.get('category_scores', {}) # Извлечение оценок категорий из словаря
15
- detected = prediction.get('detected', False) # Извлечение флага обнаружения
16
- return category_scores, str(detected) # Преобразование detected в строку для отображения
17
 
18
- # Создание интерфейса Gradio
19
  iface = gr.Interface(fn=predict_moderation,
20
  inputs="text",
21
  outputs=[gr.outputs.Label(label="Category Scores", type="confidences"),
@@ -23,5 +22,5 @@ iface = gr.Interface(fn=predict_moderation,
23
  title="Moderation Model",
24
  description="Enter text to check for moderation flags.")
25
 
26
- # Запуск интерфейса
27
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from moderation import *
4
+
5
 
 
6
  moderation = ModerationModel()
7
+ moderation.load_state_dict(torch.load('moderation_model.pth', map_location=torch.device('cpu'))) #Remove map_location if run on gpu
8
+ moderation.eval()
9
 
10
  def predict_moderation(text):
11
  embeddings_for_prediction = getEmb(text)
12
  prediction = predict(moderation, embeddings_for_prediction)
13
+ category_scores = prediction.get('category_scores', {})
14
+ detected = prediction.get('detected', False)
15
+ return category_scores, str(detected)
16
+
17
 
 
18
  iface = gr.Interface(fn=predict_moderation,
19
  inputs="text",
20
  outputs=[gr.outputs.Label(label="Category Scores", type="confidences"),
 
22
  title="Moderation Model",
23
  description="Enter text to check for moderation flags.")
24
 
25
+
26
  iface.launch()