File size: 2,610 Bytes
268bb4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06f5362
 
 
 
 
 
f35feb4
06f5362
 
268bb4a
06f5362
268bb4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6ee16c
268bb4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from PIL import Image
import numpy as np
from rembg import remove
import cv2
import os
from torchvision.transforms import GaussianBlur
import gradio as gr
import replicate
import requests
from io import BytesIO

def create_mask(input):
    """Remove the product photo's background and build an inpainting mask.

    Side effects (files in the current working directory):
      * 'input.png'       - the image as received;
      * 'bg_removed.png'  - the background-removed product, pasted centred
                            on a square canvas and resized to 512x512;
      * 'blured_mask.png' - the blurred mask returned by this function.
    generate_image uploads the last two files by name, so the names must
    stay in sync with it.

    Args:
        input: PIL image of the product.

    Returns:
        A 512x512 mode 'L' PIL image: 255 where the model should paint
        (removed background and the letterbox padding), ``d`` (0 at the
        default strength) over the product, softened by a Gaussian blur.
    """
    input_path = 'input.png'
    bg_removed_path = 'bg_removed.png'
    mask_name = 'blured_mask.png'
    out_size = (512, 512)

    input.save(input_path)
    # Presumably rembg marks the removed background as transparent
    # (alpha 0); normalise to RGBA so an alpha channel is always present.
    bg_removed = remove(input).convert('RGBA')

    width, height = bg_removed.size
    max_dim = max(width, height)
    paste_pos = ((max_dim - width) // 2, (max_dim - height) // 2)

    # Image sent to the model: product centred on a white square.
    square_img = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
    square_img.paste(bg_removed, paste_pos)
    square_img = square_img.resize(out_size)
    square_img.save(bg_removed_path)

    # Build the mask from the alpha channel rather than from pixel colour.
    # Thresholding the grayscale image at 0 used to (a) repaint any pure
    # black pixel of the product and (b) keep the white letterbox padding.
    # Here the padding starts fully transparent (0) on the square canvas.
    square_alpha = Image.new('L', (max_dim, max_dim), 0)
    square_alpha.paste(bg_removed.getchannel('A'), paste_pos)
    alpha = np.array(square_alpha.resize(out_size))

    # strength in [0, 1]: 1 keeps the product untouched (mask value 0);
    # lower values raise the product's mask value to d.
    strength = 1
    d = int(255 * (1 - strength))
    mask_arr = np.where(alpha == 0, 255, d).astype(np.uint8)

    mask = Image.fromarray(mask_arr)
    # torchvision GaussianBlur(kernel_size=11, sigma=20) feathers the edge.
    mask = GaussianBlur(11, 20)(mask)

    mask.save(mask_name)
    return mask


def generate_image(image, product_name, target_name):
    """Place the uploaded product into a new AI-generated scene.

    Calls create_mask, which writes 'bg_removed.png' and 'blured_mask.png'
    to the working directory. Both files are then uploaded to the Replicate
    model "cjwbw/stable-diffusion-v2-inpainting", together with a prompt
    built from the two text fields.

    Args:
        image: PIL image of the product from the gradio Image input, or
            None when the user did not upload anything.
        product_name: free-text description of the product.
        target_name: free-text description of where to place the product.

    Returns:
        The first generated image, as a PIL image.

    Raises:
        gr.Error: if no image was uploaded (shown to the user by gradio).
        requests.HTTPError: if downloading the generated image fails.
    """
    if image is None:
        raise gr.Error("Please upload a photo of your product first.")

    # Called for its side effect: it writes the two files uploaded below.
    create_mask(image)

    # The previous string concatenation dropped the space after "of".
    prompt = f'a product photography photo of {product_name} on {target_name}'

    model = replicate.models.get("cjwbw/stable-diffusion-v2-inpainting")
    version = model.versions.get(
        "f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec"
    )
    # Use context managers so the uploaded files are closed afterwards.
    with open("bg_removed.png", "rb") as image_file, \
            open("blured_mask.png", "rb") as mask_file:
        output = version.predict(prompt=prompt, image=image_file, mask=mask_file)

    # output[0] is fetched as a URL; bound the wait and surface HTTP errors.
    response = requests.get(output[0], timeout=60)
    response.raise_for_status()

    return Image.open(BytesIO(response.content))

# Gradio UI. The left column holds the product photo and two text prompts,
# and the right column shows the result. Clicking "Generate" runs
# generate_image, which is also exposed through the API as 'test'.
with gr.Blocks() as demo:
    gr.Markdown("# Advertise better with AI")

    with gr.Row():
        # Inputs.
        with gr.Column():
            input_image = gr.Image(type='pil', label="Upload your product's photo")
            product_name = gr.Textbox(label="Describe your product")
            target_name = gr.Textbox(label="Where do you want to put your product?")
            image_button = gr.Button("Generate")

        # Output.
        with gr.Column():
            image_output = gr.Image()

    image_button.click(
        generate_image,
        inputs=[input_image, product_name, target_name],
        outputs=image_output,
        api_name='test',
    )


demo.launch()