ifmain commited on
Commit
568fb5d
1 Parent(s): 2ffd8bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -13,6 +13,8 @@ import base64
13
  from io import BytesIO
14
  import difflib
15
 
 
 
16
 
17
  # Helper functions
18
  def image_to_base64(image: Image.Image):
@@ -90,6 +92,7 @@ def getSegments(yoloModel, img1):
90
 
91
  @spaces.GPU
92
  def getDescript(image_captioner, img1):
 
93
  base64_img = image_to_base64(img1)
94
  caption = image_captioner(base64_img)[0]['generated_text']
95
  return caption
@@ -110,9 +113,10 @@ def rmGPT(caption, remove_class, change):
110
  return ' '.join(arstr)
111
 
112
  @spaces.GPU
113
- def ChangeOBJ(sdxl_m, img1, response, mask1):
 
114
  size = img1.size
115
- image = sdxl_m(prompt=response, image=img1, mask_image=mask1).images[0]
116
  return image.resize((size[0], size[1]))
117
 
118
  # Load models initially
 
13
  from io import BytesIO
14
  import difflib
15
 
16
+ # Constants
17
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
18
 
19
  # Helper functions
20
  def image_to_base64(image: Image.Image):
 
92
 
93
  @spaces.GPU
94
  def getDescript(image_captioner, img1):
95
+ image_captioner.model.to('cuda')
96
  base64_img = image_to_base64(img1)
97
  caption = image_captioner(base64_img)[0]['generated_text']
98
  return caption
 
113
  return ' '.join(arstr)
114
 
115
  @spaces.GPU
116
+ def ChangeOBJ(sdxl_model, img1, response, mask1):
117
+ sdxl_model.to('cuda')
118
  size = img1.size
119
+ image = sdxl_model(prompt=response, image=img1, mask_image=mask1).images[0]
120
  return image.resize((size[0], size[1]))
121
 
122
  # Load models initially