meepmoo commited on
Commit
0e71afa
1 Parent(s): f56d744

Update worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker_runpod.py +16 -6
worker_runpod.py CHANGED
@@ -29,16 +29,27 @@ def download_image(url, download_dir="asset"):
29
  if not os.path.exists(download_dir):
30
  os.makedirs(download_dir, exist_ok=True)
31
 
32
- # Generate a random filename with a .png extension
33
- filename = f"{uuid.uuid4().hex}.png"
34
- file_path = os.path.join(download_dir, filename)
35
-
36
- # Download the image and save it
37
  response = requests.get(url, stream=True)
38
  if response.status_code == 200:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  with open(file_path, "wb") as f:
40
  for chunk in response.iter_content(1024):
41
  f.write(chunk)
 
42
  print(f"Image downloaded to {file_path}")
43
  return file_path
44
  else:
@@ -47,7 +58,6 @@ def download_image(url, download_dir="asset"):
47
  # Usage
48
  # validation_image_start = values.get("validation_image_start", "https://example.com/path/to/image.png")
49
  # downloaded_image_path = download_image(validation_image_start)
50
- # Model loading section
51
  model_id = "/content/model"
52
  transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
53
  model_id, subfolder="transformer", torch_dtype=torch.bfloat16
 
29
  if not os.path.exists(download_dir):
30
  os.makedirs(download_dir, exist_ok=True)
31
 
32
+ # Send the request and check for successful response
 
 
 
 
33
  response = requests.get(url, stream=True)
34
  if response.status_code == 200:
35
+ # Determine file extension based on content type
36
+ content_type = response.headers.get("Content-Type")
37
+ if content_type == "image/png":
38
+ ext = "png"
39
+ elif content_type == "image/jpeg":
40
+ ext = "jpg"
41
+ else:
42
+ ext = "jpg" # default to .jpg if content type is unrecognized
43
+
44
+ # Generate a random filename with the correct extension
45
+ filename = f"{uuid.uuid4().hex}.{ext}"
46
+ file_path = os.path.join(download_dir, filename)
47
+
48
+ # Save the image
49
  with open(file_path, "wb") as f:
50
  for chunk in response.iter_content(1024):
51
  f.write(chunk)
52
+
53
  print(f"Image downloaded to {file_path}")
54
  return file_path
55
  else:
 
58
  # Usage
59
  # validation_image_start = values.get("validation_image_start", "https://example.com/path/to/image.png")
60
  # downloaded_image_path = download_image(validation_image_start)
 
61
  model_id = "/content/model"
62
  transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
63
  model_id, subfolder="transformer", torch_dtype=torch.bfloat16