DGSpitzer commited on
Commit
e6ad2e0
1 Parent(s): eca7b41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -35
app.py CHANGED
@@ -67,44 +67,40 @@ def translate_language(text_prompts):
67
  return {language_tips_text:gr.update(visible=True, value=tips_text), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
68
 
69
 
 
 
 
 
70
  def get_result(text_prompts, style_indx):
71
- #results = text_to_img(text_prompts, style_indx, fn_index=3)
72
- print(text_prompts)
73
- #results = text_to_img(text_prompts, text_prompts, fn_index=3)
74
- print(style_indx)
75
- try:
76
  style = style_list[style_indx]
77
-
78
- #results = text_to_img(text_prompts, style_indx, fn_index=3)
79
-
80
- results = model.generate_image(
81
- text_prompts=text_prompts, style=style, visualization=False, topk=1)
82
- except Exception as e:
83
- error_text = str(e)
84
- return {video_result:None, status_text:error_text}
85
-
86
- #print("Ernie Vilg Output: " + str(results[:1]))
87
- print("Ernie Vilg Output test: " + str(results))
88
 
89
- #image_output = results[:1]
90
- image_output = results[:1]
91
-
92
- print("file name: " + image_output[0].filename)
93
-
94
- # Encode your PIL Image as a JPEG without writing to disk
95
- imagefile = "imageoutput.png"
96
- #img_np = np.array(image_output[0])
97
- #img_nparray= cv2.cvtColor(img_np, cv2.COLOR_BGR2RGBA)
98
- #img_blue_correction = Image.fromarray(img_nparray)
99
- #img_blue_correction.save(imagefile, img_blue_correction.format)
100
- image_output[0].save(imagefile, image_output[0].format)
101
-
102
- interrogate_prompt = img_to_text(imagefile, fn_index=1)[0]
103
- print(interrogate_prompt)
104
- spec_image, music_output = get_music(interrogate_prompt + ", " + style_list_EN[style_indx])
105
-
106
- video_merged = merge_video(music_output, image_output)
107
- return {spec_result:spec_image, video_result:video_merged, status_text:'Success'}
108
 
109
  def get_music(prompt):
110
  spec = pipe(prompt).images[0]
 
67
  return {language_tips_text:gr.update(visible=True, value=tips_text), translated_language:text_prompts, trigger_component: gr.update(value=count, visible=False)}
68
 
69
 
70
+ word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
71
+ word_list = word_list_dataset["train"]['text']
72
+
73
+
74
  def get_result(text_prompts, style_indx):
 
 
 
 
 
75
  style = style_list[style_indx]
76
+ prompt = "," + style
77
+ for filter in word_list:
78
+ if re.search(rf"\b{filter}\b", text_prompts):
79
+ raise gr.Error("Unsafe content found. Please try again with different prompts.")
80
+
81
+ model_id = "runwayml/stable-diffusion-v1-5"
82
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
83
+ pipe = pipe.to("cuda")
 
 
 
84
 
85
+ prompt = text_prompts
86
+ image_output = pipe(prompt).images[0]
87
+
88
+ print("file name: " + image_output.filename)
89
+
90
+ # Encode your PIL Image as a JPEG without writing to disk
91
+ imagefile = "imageoutput.png"
92
+ #img_np = np.array(image_output[0])
93
+ #img_nparray= cv2.cvtColor(img_np, cv2.COLOR_BGR2RGBA)
94
+ #img_blue_correction = Image.fromarray(img_nparray)
95
+ #img_blue_correction.save(imagefile, img_blue_correction.format)
96
+ image_output[0].save(imagefile, image_output[0].format)
97
+
98
+ interrogate_prompt = img_to_text(imagefile, fn_index=1)[0]
99
+ print(interrogate_prompt)
100
+ spec_image, music_output = get_music(interrogate_prompt + ", " + style_list_EN[style_indx])
101
+
102
+ video_merged = merge_video(music_output, image_output)
103
+ return {spec_result:spec_image, video_result:video_merged, status_text:'Success'}
104
 
105
  def get_music(prompt):
106
  spec = pipe(prompt).images[0]