karwanjiru commited on
Commit
b043438
1 Parent(s): 416e2a6
Files changed (1) hide show
  1. app.py +12 -20
app.py CHANGED
@@ -12,16 +12,8 @@ import requests
12
  from io import BytesIO
13
 
14
  # Paths and model setup
15
- image_folder = "path_to_your_image_folder" # Specify the path to your image folder
16
  model_path = "MichalMlodawski/nsfw-image-detection-large"
17
 
18
- # List of jpg files in the folder
19
- jpg_files = [file for file in os.listdir(image_folder) if file.lower().endswith(".jpg")]
20
-
21
- if not jpg_files:
22
- print("🚫 No jpg files found in folder:", image_folder)
23
- exit()
24
-
25
  # Load the model and feature extractor
26
  feature_extractor = AutoProcessor.from_pretrained(model_path)
27
  model = FocalNetForImageClassification.from_pretrained(model_path)
@@ -143,6 +135,18 @@ def moderate_image(image):
143
  else:
144
  return "Image does not adhere to community guidelines."
145
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # Create the Gradio interface
147
  css = """
148
  #col-container {
@@ -216,18 +220,6 @@ with gr.Blocks(css=css) as demo:
216
  selected_image = gr.Image(type="pil", label="Upload Image for NSFW Classification")
217
  classify_button = gr.Button("Classify Image")
218
  classification_result = gr.Textbox(label="Classification Result")
219
-
220
- def classify_nsfw(image):
221
- image_tensor = transform(image).unsqueeze(0)
222
- inputs = feature_extractor(images=image, return_tensors="pt")
223
- with torch.no_grad():
224
- outputs = model(**inputs)
225
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
226
- confidence, predicted = torch.max(probabilities, 1)
227
- label = model.config.id2label[predicted.item()]
228
- category = label_to_category.get(label, "Unknown")
229
- return f"Label: {label}, Category: {category}, Confidence: {confidence.item() * 100:.2f}%"
230
-
231
  classify_button.click(classify_nsfw, selected_image, classification_result)
232
 
233
  demo.launch()
 
12
  from io import BytesIO
13
 
14
  # Paths and model setup
 
15
  model_path = "MichalMlodawski/nsfw-image-detection-large"
16
 
 
 
 
 
 
 
 
17
  # Load the model and feature extractor
18
  feature_extractor = AutoProcessor.from_pretrained(model_path)
19
  model = FocalNetForImageClassification.from_pretrained(model_path)
 
135
  else:
136
  return "Image does not adhere to community guidelines."
137
 
138
+ # Function to classify NSFW images
139
+ def classify_nsfw(image):
140
+ image_tensor = transform(image).unsqueeze(0)
141
+ inputs = feature_extractor(images=image, return_tensors="pt")
142
+ with torch.no_grad():
143
+ outputs = model(**inputs)
144
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
145
+ confidence, predicted = torch.max(probabilities, 1)
146
+ label = model.config.id2label[predicted.item()]
147
+ category = label_to_category.get(label, "Unknown")
148
+ return f"Label: {label}, Category: {category}, Confidence: {confidence.item() * 100:.2f}%"
149
+
150
  # Create the Gradio interface
151
  css = """
152
  #col-container {
 
220
  selected_image = gr.Image(type="pil", label="Upload Image for NSFW Classification")
221
  classify_button = gr.Button("Classify Image")
222
  classification_result = gr.Textbox(label="Classification Result")
 
 
 
 
 
 
 
 
 
 
 
 
223
  classify_button.click(classify_nsfw, selected_image, classification_result)
224
 
225
  demo.launch()