MuhammadHanif commited on
Commit
d0c86d7
1 Parent(s): acd4224

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -17,10 +17,10 @@ pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
17
  use_memory_efficient_attention=True
18
  )
19
 
20
- def infer(prompts, negative_prompts):
21
 
22
  num_samples = 1 #jax.device_count()
23
- rng = create_key(0)
24
  rng = jax.random.split(rng, jax.device_count())
25
 
26
  prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
@@ -33,10 +33,10 @@ def infer(prompts, negative_prompts):
33
  output = pipe(
34
  prompt_ids=prompt_ids,
35
  params=p_params,
36
- height=1088,
37
- width=1088,
38
  prng_seed=rng,
39
- num_inference_steps=50,
40
  neg_prompt_ids=negative_prompt_ids,
41
  jit=True,
42
  ).images
@@ -44,4 +44,41 @@ def infer(prompts, negative_prompts):
44
  output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
45
  return output_images
46
 
47
- gr.Interface(infer, inputs=["text", "text"], outputs="gallery").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  use_memory_efficient_attention=True
18
  )
19
 
20
+ def infer(prompts, negative_prompts, width=1088, height=1088, inference_steps=30, seed=0):
21
 
22
  num_samples = 1 #jax.device_count()
23
+ rng = create_key(seed)
24
  rng = jax.random.split(rng, jax.device_count())
25
 
26
  prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
 
33
  output = pipe(
34
  prompt_ids=prompt_ids,
35
  params=p_params,
36
+ height=height,
37
+ width=width,
38
  prng_seed=rng,
39
+ num_inference_steps=inference_steps,
40
  neg_prompt_ids=negative_prompt_ids,
41
  jit=True,
42
  ).images
 
44
  output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
45
  return output_images
46
 
47
+ prompt_input = gr.inputs.Textbox(
48
+ label="Prompt",
49
+ placeholder="a highly detailed mansion in the autumn by studio ghibli, makoto shinkai"
50
+ )
51
+ neg_prompt_input = gr.inputs.Textbox(
52
+ label="Negative Prompt",
53
+ placeholder=""
54
+ )
55
+
56
+ width_slider = gr.inputs.Slider(
57
+ minimum=512, maximum=2048, default=30, step=64, label="width"
58
+ )
59
+
60
+ height_slider = gr.inputs.Slider(
61
+ minimum=512, maximum=2048, default=30, step=64, label="height"
62
+ )
63
+
64
+ inf_steps_input = gr.inputs.Slider(
65
+ minimum=1, maximum=100, default=30, step=1, label="Inference Steps"
66
+ )
67
+
68
+
69
+ seed_input = gr.inputs.Number(default=0, label="Seed")
70
+
71
+ app = gr.Interface(
72
+ fn=infer,
73
+ inputs=[prompt_input, neg_prompt_input, width_slider, height_slider, inf_steps_input, seed_input],
74
+ outputs="image",
75
+ title="Stable Diffusion High Resolution",
76
+ description=(
77
+ "Based on stable diffusion 1.5 and fine-tuned on 576x576 up to 1088x1088 images, "
78
+ "Stable Diffusion High Resolution is compartible with another SD1.5 model and mergeable with other SD1.5 model, "
79
+ "giving other model to generate high resolution images without using upscaler."
80
+ ),
81
+ examples=[["a highly detailed mansion in the autumn by studio ghibli, makoto shinkai","", 1088, 1088, 30, 0]],
82
+ )
83
+
84
+ app.launch()