Update worker_runpod.py
Browse files- worker_runpod.py +16 -6
worker_runpod.py
CHANGED
@@ -29,16 +29,27 @@ def download_image(url, download_dir="asset"):
|
|
29 |
if not os.path.exists(download_dir):
|
30 |
os.makedirs(download_dir, exist_ok=True)
|
31 |
|
32 |
-
#
|
33 |
-
filename = f"{uuid.uuid4().hex}.png"
|
34 |
-
file_path = os.path.join(download_dir, filename)
|
35 |
-
|
36 |
-
# Download the image and save it
|
37 |
response = requests.get(url, stream=True)
|
38 |
if response.status_code == 200:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
with open(file_path, "wb") as f:
|
40 |
for chunk in response.iter_content(1024):
|
41 |
f.write(chunk)
|
|
|
42 |
print(f"Image downloaded to {file_path}")
|
43 |
return file_path
|
44 |
else:
|
@@ -47,7 +58,6 @@ def download_image(url, download_dir="asset"):
|
|
47 |
# Usage
|
48 |
# validation_image_start = values.get("validation_image_start", "https://example.com/path/to/image.png")
|
49 |
# downloaded_image_path = download_image(validation_image_start)
|
50 |
-
# Model loading section
|
51 |
model_id = "/content/model"
|
52 |
transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
|
53 |
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
|
|
29 |
if not os.path.exists(download_dir):
|
30 |
os.makedirs(download_dir, exist_ok=True)
|
31 |
|
32 |
+
# Send the request and check for successful response
|
|
|
|
|
|
|
|
|
33 |
response = requests.get(url, stream=True)
|
34 |
if response.status_code == 200:
|
35 |
+
# Determine file extension based on content type
|
36 |
+
content_type = response.headers.get("Content-Type")
|
37 |
+
if content_type == "image/png":
|
38 |
+
ext = "png"
|
39 |
+
elif content_type == "image/jpeg":
|
40 |
+
ext = "jpg"
|
41 |
+
else:
|
42 |
+
ext = "jpg" # default to .jpg if content type is unrecognized
|
43 |
+
|
44 |
+
# Generate a random filename with the correct extension
|
45 |
+
filename = f"{uuid.uuid4().hex}.{ext}"
|
46 |
+
file_path = os.path.join(download_dir, filename)
|
47 |
+
|
48 |
+
# Save the image
|
49 |
with open(file_path, "wb") as f:
|
50 |
for chunk in response.iter_content(1024):
|
51 |
f.write(chunk)
|
52 |
+
|
53 |
print(f"Image downloaded to {file_path}")
|
54 |
return file_path
|
55 |
else:
|
|
|
58 |
# Usage
|
59 |
# validation_image_start = values.get("validation_image_start", "https://example.com/path/to/image.png")
|
60 |
# downloaded_image_path = download_image(validation_image_start)
|
|
|
61 |
model_id = "/content/model"
|
62 |
transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
|
63 |
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|