Tinsae commited on
Commit
268bb4a
1 Parent(s): 2f7bd8a
Files changed (2) hide show
  1. app.py +82 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ from rembg import remove
4
+ import cv2
5
+ import os
6
+ from torchvision.transforms import GaussianBlur
7
+ import gradio as gr
8
+ import replicate
9
+ import requests
10
+ from io import BytesIO
11
+
12
+ def create_mask(input):
13
+ input_path = 'input.png'
14
+ bg_removed_path = 'bg_removed.png'
15
+ mask_name = 'blured_mask.png'
16
+
17
+ input.save(input_path)
18
+ bg_removed = remove(input)
19
+ bg_removed = bg_removed.resize((512, 512))
20
+ bg_removed.save(bg_removed_path)
21
+
22
+ img2_grayscale = bg_removed.convert('L')
23
+ img2_a = np.array(img2_grayscale)
24
+
25
+ mask = np.array(img2_grayscale)
26
+ threshhold = 0
27
+ mask[img2_a==threshhold] = 1
28
+ mask[img2_a>threshhold] = 0
29
+
30
+ strength = 1
31
+ d = int(255 * (1-strength))
32
+ mask *= 255-d
33
+ mask += d
34
+
35
+ mask = Image.fromarray(mask)
36
+
37
+ blur = GaussianBlur(11,20)
38
+ mask = blur(mask)
39
+ mask = mask.resize((512, 512))
40
+
41
+ mask.save(mask_name)
42
+
43
+ return Image.open(mask_name)
44
+
45
+
46
+ def generate_image(image, product_name, target_name):
47
+ mask = create_mask(image)
48
+ image = image.resize((512, 512))
49
+ mask = mask.resize((512,512))
50
+ guidance_scale=16
51
+ num_samples = 1
52
+
53
+ prompt = 'a photo of a ' + product_name + ' with ' + target_name + ' product photograpy'
54
+
55
+ model = replicate.models.get("cjwbw/stable-diffusion-v2-inpainting")
56
+ version = model.versions.get("f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec")
57
+ output = version.predict(prompt=prompt, image=open("bg_removed.png", "rb"), mask=open("blured_mask.png", "rb"))
58
+ response = requests.get(output[0])
59
+
60
+ return Image.open(BytesIO(response.content))
61
+
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown("# Advertise better with AI")
64
+ # with gr.Tab("Prompt Paint - Basic"):
65
+ with gr.Row():
66
+
67
+ with gr.Column():
68
+ input_image = gr.Image(label = "Upload your product's photo", type = 'pil')
69
+
70
+ product_name = gr.Textbox(label="Describe your product")
71
+ target_name = gr.Textbox(label="Where do you want to put your product?")
72
+ # result_prompt = product_name + ' in ' + target_name + 'product photograpy ultrarealist'
73
+
74
+ image_button = gr.Button("Generate")
75
+
76
+ with gr.Column():
77
+ image_output = gr.Image()
78
+
79
+ image_button.click(generate_image, inputs=[input_image, product_name, target_name ], outputs=image_output, api_name='test')
80
+
81
+
82
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ rembg==2.0.30
2
+ torchvision==0.13.1
3
+ numpy==1.23.5
4
+ Pillow==9.3.0
5
+ opencv-python==4.6.0.66
6
+ gradio==3.9.1