ifmain commited on
Commit
4a20809
1 Parent(s): ded2d5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -28,7 +28,7 @@ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
28
  print(DEVICE)
29
 
30
  yoloModel = YOLO('yolov8x-seg.pt')
31
- pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to("cuda")
32
  sdxl.to("cuda")
33
 
34
  image_captioner = pipeline("image-to-text", model="Abdou/vit-swin-base-224-gpt2-image-captioning", device=DEVICE)
@@ -54,8 +54,8 @@ def get_most_similar_string(target_string, string_array):
54
 
55
 
56
  # Yolo
57
- def getClasses(model, img1):
58
- results = model([np.array(img1)], device='cpu') # Изменение для передачи изображения как массива NumPy
59
  out = []
60
  for r in results:
61
  im_array = r.plot()
@@ -101,8 +101,8 @@ def joinClasses(classes):
101
  return allMask
102
 
103
 
104
- def getSegments(yoloModel, img1):
105
- classes, image, results1 = getClasses(yoloModel, img1)
106
  allMask = joinClasses(classes)
107
  return allMask
108
 
@@ -140,7 +140,7 @@ def ChangeOBJ(img1, response, mask1):
140
  def full_pipeline(image, target):
141
  img1 = Image.fromarray(image.astype('uint8'), 'RGB')
142
  img1 = img1.resize((512, 512))
143
- allMask = getSegments(yoloModel, img1)
144
  tartget_to_remove = get_most_similar_string(target, list(allMask.keys()))
145
  caption = getDescript(image_captioner, img1)
146
 
 
28
  print(DEVICE)
29
 
30
  yoloModel = YOLO('yolov8x-seg.pt')
31
+ sdxl = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to("cuda")
32
  sdxl.to("cuda")
33
 
34
  image_captioner = pipeline("image-to-text", model="Abdou/vit-swin-base-224-gpt2-image-captioning", device=DEVICE)
 
54
 
55
 
56
  # Yolo
57
+ def getClasses(img1):
58
+ results = yoloModel([np.array(img1)], device='cpu') # Изменение для передачи изображения как массива NumPy
59
  out = []
60
  for r in results:
61
  im_array = r.plot()
 
101
  return allMask
102
 
103
 
104
+ def getSegments(img1):
105
+ classes, image, results1 = getClasses(img1)
106
  allMask = joinClasses(classes)
107
  return allMask
108
 
 
140
  def full_pipeline(image, target):
141
  img1 = Image.fromarray(image.astype('uint8'), 'RGB')
142
  img1 = img1.resize((512, 512))
143
+ allMask = getSegments(img1)
144
  tartget_to_remove = get_most_similar_string(target, list(allMask.keys()))
145
  caption = getDescript(image_captioner, img1)
146