Brice Vandeputte commited on
Commit
07775e1
1 Parent(s): b765164

revert accelerate, add webhook

Browse files
Files changed (2) hide show
  1. app.py +15 -10
  2. requirements.txt +1 -2
app.py CHANGED
@@ -2,10 +2,7 @@ import collections
2
  import heapq
3
  import json
4
  import os
5
-
6
- # https://huggingface.co/docs/accelerate/main/en/basic_tutorials/troubleshooting#logging
7
- # https://pypi.org/project/accelerate/
8
- from accelerate.logging import get_logger
9
 
10
  import gradio as gr
11
  import numpy as np
@@ -15,10 +12,11 @@ from open_clip import create_model, get_tokenizer
15
  from torchvision import transforms
16
 
17
  from templates import openai_imagenet_template
 
18
 
19
- # log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
20
- # logging.basicConfig(level=logging.INFO, format=log_format)
21
- logger = get_logger(__name__, log_level="INFO")
22
 
23
  hf_token = os.getenv("HF_TOKEN")
24
 
@@ -155,8 +153,6 @@ def open_domain_classification(img, rank: int) -> dict[str, float]:
155
 
156
  return {name: output[name] for name in topk_names}
157
 
158
-
159
-
160
  @torch.no_grad()
161
  def api_classification(img, rank: int): # -> dict[str, float]:
162
  """
@@ -185,7 +181,7 @@ def api_classification(img, rank: int): # -> dict[str, float]:
185
 
186
  logger.info(">>>>")
187
  logger.info(probs[0])
188
- return probs[0]
189
  # topk_names = heapq.nlargest(k, output, key=output.get)
190
  # return {name: output[name] for name in topk_names}
191
 
@@ -193,6 +189,13 @@ def api_classification(img, rank: int): # -> dict[str, float]:
193
  def change_output(choice):
194
  return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
195
 
 
 
 
 
 
 
 
196
 
197
  if __name__ == "__main__":
198
  logger.info("Starting.")
@@ -332,3 +335,5 @@ if __name__ == "__main__":
332
 
333
  app.queue(max_size=20)
334
  app.launch()
 
 
 
2
  import heapq
3
  import json
4
  import os
5
+ import logging
 
 
 
6
 
7
  import gradio as gr
8
  import numpy as np
 
12
  from torchvision import transforms
13
 
14
  from templates import openai_imagenet_template
15
+ from huggingface_hub import webhook_endpoint, WebhookPayload
16
 
17
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
18
+ logging.basicConfig(level=logging.INFO, format=log_format)
19
+ logger = logging.getLogger()
20
 
21
  hf_token = os.getenv("HF_TOKEN")
22
 
 
153
 
154
  return {name: output[name] for name in topk_names}
155
 
 
 
156
  @torch.no_grad()
157
  def api_classification(img, rank: int): # -> dict[str, float]:
158
  """
 
181
 
182
  logger.info(">>>>")
183
  logger.info(probs[0])
184
+ return {"message": probs[0]}
185
  # topk_names = heapq.nlargest(k, output, key=output.get)
186
  # return {name: output[name] for name in topk_names}
187
 
 
189
  def change_output(choice):
190
  return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
191
 
192
+ @webhook_endpoint
193
+ # https://huggingface.co/docs/huggingface_hub/guides/webhooks_server
194
+ async def trigger_test(payload: WebhookPayload):
195
+ logger.info(payload)
196
+ if payload.repo.type == "dataset" and payload.event.action == "update":
197
+ logger.info("oo")
198
+ return {"message": "hello"}
199
 
200
  if __name__ == "__main__":
201
  logger.info("Starting.")
 
335
 
336
  app.queue(max_size=20)
337
  app.launch()
338
+
339
+ # app.py
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  open_clip_torch
2
  torchvision
3
  torch
4
- gradio
5
- accelerate
 
1
  open_clip_torch
2
  torchvision
3
  torch
4
+ gradio