Create worker_runpod.py
Browse files- worker_runpod.py +102 -0
worker_runpod.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, json, tempfile, requests, runpod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
from transformers import AutoModelForImageSegmentation
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
def download_file(url, save_dir='/content/input'):
    """Download *url* into *save_dir* and return the local file path.

    The file name is taken from the URL's path component only, so query
    strings ('?download=1' style) do not leak into the name on disk.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server stalls past the timeout.
    """
    from urllib.parse import urlparse

    os.makedirs(save_dir, exist_ok=True)
    file_name = os.path.basename(urlparse(url).path)
    file_path = os.path.join(save_dir, file_name)
    # Timeout so a stalled remote server cannot hang the worker forever:
    # 10 s to connect, 300 s to read the body.
    response = requests.get(url, timeout=(10, 300))
    response.raise_for_status()
    with open(file_path, 'wb') as file:
        file.write(response.content)
    return file_path
|
17 |
+
|
18 |
+
# One-time model setup, executed at import time so every job reuses the
# already-loaded network.
with torch.inference_mode():
    torch.set_float32_matmul_precision("high")
    # BiRefNet checkpoint baked into the container image at /content/BiRefNet.
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "/content/BiRefNet", trust_remote_code=True
    ).to("cuda")
    # Resize to the model's 1024x1024 input and apply ImageNet mean/std
    # normalization.
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
|
28 |
+
|
29 |
+
def _post_json(uri, token, payload):
    """POST *payload* as JSON to *uri* with *token* as the Authorization header."""
    requests.post(
        uri,
        data=json.dumps(payload),
        headers={'Content-Type': 'application/json', "Authorization": token},
    )

@torch.inference_mode()
def generate(input):
    """RunPod serverless handler.

    Expects ``input = {"input": {...}}`` containing 'input_image_url' plus the
    notify/discord metadata keys. Downloads the image, runs BiRefNet to
    predict a foreground mask, writes the mask into the image's alpha channel,
    uploads the PNG to a Discord channel, and reports the attachment URL to
    the notify endpoint(s).

    Returns:
        dict: {"jobId", "result", "status"} with status "DONE" or "FAILED".
    """
    values = input["input"]

    # --- inference -------------------------------------------------------
    input_image_path = download_file(values['input_image_url'])
    input_image = Image.open(input_image_path)
    image_size = input_image.size
    batch = transform_image(input_image).unsqueeze(0).to("cuda")
    # BiRefNet returns a list of side outputs; the last entry is the final mask.
    preds = birefnet(batch)[-1].sigmoid().cpu()
    mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)
    input_image.putalpha(mask)
    result = "/content/birefnet_tost.png"
    input_image.save(result)

    # --- notification metadata ------------------------------------------
    # Pop the metadata keys BEFORE the try block. The original read them
    # inside it, so a missing key raised KeyError before job_id was bound and
    # the except path then crashed with NameError. Placeholder values fall
    # back to environment configuration, as before.
    notify_uri = values.pop('notify_uri', "notify_uri")
    notify_token = values.pop('notify_token', "notify_token")
    discord_id = values.pop('discord_id', "discord_id")
    if discord_id == "discord_id":
        discord_id = os.getenv('com_camenduru_discord_id')
    discord_channel = values.pop('discord_channel', "discord_channel")
    if discord_channel == "discord_channel":
        discord_channel = os.getenv('com_camenduru_discord_channel')
    discord_token = values.pop('discord_token', "discord_token")
    if discord_token == "discord_token":
        discord_token = os.getenv('com_camenduru_discord_token')
    job_id = values.pop('job_id', None)
    web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
    web_notify_token = os.getenv('com_camenduru_web_notify_token')

    try:
        # Upload the PNG to Discord; the attachment URL becomes the result.
        with open(result, "rb") as file:
            files = {os.path.basename(result): file.read()}
        payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
        response = requests.post(
            f"https://discord.com/api/v9/channels/{discord_channel}/messages",
            data=payload,
            headers={"Authorization": f"Bot {discord_token}"},
            files=files
        )
        response.raise_for_status()
        result_url = response.json()['attachments'][0]['url']
        notify_payload = {"jobId": job_id, "result": result_url, "status": "DONE"}
        # Always report to the web endpoint; also to the caller's custom
        # endpoint when one was supplied (the placeholder string means "none").
        _post_json(web_notify_uri, web_notify_token, notify_payload)
        if notify_uri != "notify_uri":
            _post_json(notify_uri, notify_token, notify_payload)
        return notify_payload
    except Exception as e:
        error_payload = {"jobId": job_id, "status": "FAILED"}
        try:
            _post_json(web_notify_uri, web_notify_token, error_payload)
            if notify_uri != "notify_uri":
                _post_json(notify_uri, notify_token, error_payload)
        except Exception:
            pass  # failure notifications are best-effort by design
        return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
    finally:
        # Remove the temporary PNG so repeated invocations don't accumulate files.
        if os.path.exists(result):
            os.remove(result)
|
101 |
+
|
102 |
+
# Register generate() as the job handler and start the RunPod serverless
# worker loop (blocks, polling for jobs).
runpod.serverless.start({"handler": generate})
|