twn39 committed on
Commit
85e7f21
1 Parent(s): 44247fb
Files changed (5) hide show
  1. app.py +51 -13
  2. pyproject.toml +9 -3
  3. requirements-dev.lock +14 -8
  4. requirements.lock +14 -8
  5. requirements.txt +14 -9
app.py CHANGED
@@ -8,6 +8,32 @@ from modelscope.utils.constant import Tasks
8
  import cv2
9
  from diffusers import StableDiffusionXLPipeline
10
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
@@ -71,22 +97,26 @@ def bg_remove(_image, _type):
71
  return vis_img
72
 
73
 
74
- def text_to_image(_image, _prompt):
 
75
  t2i_pipeline = StableDiffusionXLPipeline.from_pretrained(
76
  "stabilityai/stable-diffusion-xl-base-1.0",
77
  torch_dtype=torch.float16,
78
  variant="fp16",
79
  use_safetensors=True,
80
  ).to("cuda")
81
- result = t2i_pipeline(
82
- prompt=_prompt,
83
- negative_prompt='ugly',
84
- num_inference_steps=22,
85
- width=1024,
86
- height=1024,
87
- guidance_scale=7,
88
- ).images[0]
89
- return result
 
 
 
90
 
91
 
92
  with gr.Blocks() as app:
@@ -254,10 +284,18 @@ with gr.Blocks() as app:
254
  image = gr.Image()
255
 
256
  with gr.Column(scale=1, min_width=300):
257
- with gr.Accordion(label="图像生成"):
258
  prompt = gr.Textbox(label="提示语", value="", lines=3)
259
- t2i_btn = gr.Button('画图', variant='primary')
260
- t2i_btn.click(fn=text_to_image, inputs=[prompt, image], outputs=[image])
 
 
 
 
 
 
 
 
261
 
262
 
263
  app.launch(debug=settings.debug, show_api=False)
 
8
  import cv2
9
  from diffusers import StableDiffusionXLPipeline
10
  import torch
11
+ from diffusers import (
12
+ DiffusionPipeline,
13
+ DPMSolverMultistepScheduler,
14
+ DDIMScheduler,
15
+ HeunDiscreteScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ EulerDiscreteScheduler,
18
+ PNDMScheduler
19
+ )
20
+
21
+
22
+ class KarrasDPM:
23
+ @staticmethod
24
+ def from_config(config):
25
+ return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True)
26
+
27
+
28
+ SCHEDULERS = {
29
+ "DDIM": DDIMScheduler,
30
+ "DPMSolverMultistep": DPMSolverMultistepScheduler,
31
+ "HeunDiscrete": HeunDiscreteScheduler,
32
+ "KarrasDPM": KarrasDPM,
33
+ "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
34
+ "K_EULER": EulerDiscreteScheduler,
35
+ "PNDM": PNDMScheduler,
36
+ }
37
 
38
 
39
  deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
 
97
  return vis_img
98
 
99
 
100
+ def text_to_image(_prompt: str, _n_prompt: str, _scheduler: str, _inference_steps: int, _w: int, _h: int, _guidance_scale: float):
101
+ print('????????', _prompt, _scheduler, _inference_steps, _w, _h, _guidance_scale)
102
  t2i_pipeline = StableDiffusionXLPipeline.from_pretrained(
103
  "stabilityai/stable-diffusion-xl-base-1.0",
104
  torch_dtype=torch.float16,
105
  variant="fp16",
106
  use_safetensors=True,
107
  ).to("cuda")
108
+ t2i_pipeline.scheduler = SCHEDULERS[_scheduler].from_config(t2i_pipeline.scheduler.config)
109
+ t2i_pipeline.enable_xformers_memory_efficient_attention()
110
+ with torch.inference_mode():
111
+ result = t2i_pipeline(
112
+ prompt=_prompt,
113
+ negative_prompt=_n_prompt,
114
+ num_inference_steps=_inference_steps,
115
+ width=_w,
116
+ height=_h,
117
+ guidance_scale=_guidance_scale,
118
+ ).images[0]
119
+ return result
120
 
121
 
122
  with gr.Blocks() as app:
 
284
  image = gr.Image()
285
 
286
  with gr.Column(scale=1, min_width=300):
287
+ with gr.Accordion(label="提示词", open=True):
288
  prompt = gr.Textbox(label="提示语", value="", lines=3)
289
+ negative_prompt = gr.Textbox(label="负提示语", value="ugly", lines=2)
290
+ with gr.Accordion(label="参数设置", open=False):
291
+ scheduler = gr.Dropdown(label='scheduler', choices=list(SCHEDULERS.keys()), value='KarrasDPM')
292
+ inference_steps = gr.Number(label='inference steps', value=22, minimum=1, maximum=100)
293
+ width = gr.Dropdown(label='width', choices=[512, 768, 832, 896, 1024, 1152], value=1024)
294
+ height = gr.Dropdown(label='height', choices=[512, 768, 832, 896, 1024, 1152], value=1024)
295
+ guidance_scale = gr.Number(label='guidance scale', value=7.0, minimum=1.0, maximum=10.0)
296
+ with gr.Row(variant='panel'):
297
+ t2i_btn = gr.Button('🪄生成', variant='primary')
298
+ t2i_btn.click(fn=text_to_image, inputs=[prompt, negative_prompt, scheduler, inference_steps, width, height, guidance_scale], outputs=[image])
299
 
300
 
301
  app.launch(debug=settings.debug, show_api=False)
pyproject.toml CHANGED
@@ -14,9 +14,9 @@ dependencies = [
14
  "langchain-openai>=0.1.16",
15
  "dashscope>=1.20.1",
16
  "modelscope[framework]>=1.16.1",
17
- "torch==2.3.1",
18
- "torchvision>=0.18.1",
19
- "torchaudio>=2.3.1",
20
  "setuptools==69.5.1",
21
  "oss2>=2.18.6",
22
  "kornia>=0.7.3",
@@ -24,6 +24,7 @@ dependencies = [
24
  "tensorflow>=2.17.0",
25
  "transformers>=4.42.4",
26
  "accelerate>=0.32.1",
 
27
  ]
28
  readme = "README.md"
29
  requires-python = ">= 3.8"
@@ -36,6 +37,11 @@ build-backend = "hatchling.build"
36
  managed = true
37
  dev-dependencies = []
38
 
 
 
 
 
 
39
  [tool.hatch.metadata]
40
  allow-direct-references = true
41
 
 
14
  "langchain-openai>=0.1.16",
15
  "dashscope>=1.20.1",
16
  "modelscope[framework]>=1.16.1",
17
+ "torch==2.3.1+cu121",
18
+ "torchvision==0.18.1+cu121",
19
+ "torchaudio==2.3.1+cu121",
20
  "setuptools==69.5.1",
21
  "oss2>=2.18.6",
22
  "kornia>=0.7.3",
 
24
  "tensorflow>=2.17.0",
25
  "transformers>=4.42.4",
26
  "accelerate>=0.32.1",
27
+ "xformers>=0.0.27",
28
  ]
29
  readme = "README.md"
30
  requires-python = ">= 3.8"
 
37
  managed = true
38
  dev-dependencies = []
39
 
40
+ [[tool.rye.sources]]
41
+ name = "torch"
42
+ url = "https://download.pytorch.org/whl/cu121"
43
+ type = "index"
44
+
45
  [tool.hatch.metadata]
46
  allow-direct-references = true
47
 
requirements-dev.lock CHANGED
@@ -7,6 +7,7 @@
7
  # all-features: false
8
  # with-sources: false
9
  # generate-hashes: false
 
10
 
11
  -e file:.
12
  absl-py==2.1.0
@@ -221,6 +222,7 @@ ml-dtypes==0.4.0
221
  # via tensorflow-intel
222
  modelscope==1.16.1
223
  # via aitoolkits-webui
 
224
  mpmath==1.3.0
225
  # via sympy
226
  multidict==6.0.5
@@ -253,6 +255,7 @@ numpy==1.26.4
253
  # via tensorflow-intel
254
  # via torchvision
255
  # via transformers
 
256
  openai==1.35.13
257
  # via langchain-openai
258
  opencv-python==4.10.0.84
@@ -378,11 +381,6 @@ scipy==1.14.0
378
  # via modelscope
379
  semantic-version==2.10.0
380
  # via gradio
381
- setuptools==69.5.1
382
- # via aitoolkits-webui
383
- # via modelscope
384
- # via tensorboard
385
- # via tensorflow-intel
386
  shellingham==1.5.4
387
  # via typer
388
  simplejson==3.19.2
@@ -431,15 +429,16 @@ tomlkit==0.12.0
431
  # via gradio
432
  toolz==0.12.1
433
  # via altair
434
- torch==2.3.1
435
  # via accelerate
436
  # via aitoolkits-webui
437
  # via kornia
438
  # via torchaudio
439
  # via torchvision
440
- torchaudio==2.3.1
 
441
  # via aitoolkits-webui
442
- torchvision==0.18.1
443
  # via aitoolkits-webui
444
  tqdm==4.66.4
445
  # via datasets
@@ -492,9 +491,16 @@ wheel==0.43.0
492
  # via astunparse
493
  wrapt==1.16.0
494
  # via tensorflow-intel
 
 
495
  xxhash==3.4.1
496
  # via datasets
497
  yarl==1.9.4
498
  # via aiohttp
499
  zipp==3.19.2
500
  # via importlib-metadata
 
 
 
 
 
 
7
  # all-features: false
8
  # with-sources: false
9
  # generate-hashes: false
10
+ # universal: false
11
 
12
  -e file:.
13
  absl-py==2.1.0
 
222
  # via tensorflow-intel
223
  modelscope==1.16.1
224
  # via aitoolkits-webui
225
+ # via modelscope
226
  mpmath==1.3.0
227
  # via sympy
228
  multidict==6.0.5
 
255
  # via tensorflow-intel
256
  # via torchvision
257
  # via transformers
258
+ # via xformers
259
  openai==1.35.13
260
  # via langchain-openai
261
  opencv-python==4.10.0.84
 
381
  # via modelscope
382
  semantic-version==2.10.0
383
  # via gradio
 
 
 
 
 
384
  shellingham==1.5.4
385
  # via typer
386
  simplejson==3.19.2
 
429
  # via gradio
430
  toolz==0.12.1
431
  # via altair
432
+ torch==2.3.1+cu121
433
  # via accelerate
434
  # via aitoolkits-webui
435
  # via kornia
436
  # via torchaudio
437
  # via torchvision
438
+ # via xformers
439
+ torchaudio==2.3.1+cu121
440
  # via aitoolkits-webui
441
+ torchvision==0.18.1+cu121
442
  # via aitoolkits-webui
443
  tqdm==4.66.4
444
  # via datasets
 
491
  # via astunparse
492
  wrapt==1.16.0
493
  # via tensorflow-intel
494
+ xformers==0.0.27
495
+ # via aitoolkits-webui
496
  xxhash==3.4.1
497
  # via datasets
498
  yarl==1.9.4
499
  # via aiohttp
500
  zipp==3.19.2
501
  # via importlib-metadata
502
+ setuptools==69.5.1
503
+ # via aitoolkits-webui
504
+ # via modelscope
505
+ # via tensorboard
506
+ # via tensorflow-intel
requirements.lock CHANGED
@@ -7,6 +7,7 @@
7
  # all-features: false
8
  # with-sources: false
9
  # generate-hashes: false
 
10
 
11
  -e file:.
12
  absl-py==2.1.0
@@ -221,6 +222,7 @@ ml-dtypes==0.4.0
221
  # via tensorflow-intel
222
  modelscope==1.16.1
223
  # via aitoolkits-webui
 
224
  mpmath==1.3.0
225
  # via sympy
226
  multidict==6.0.5
@@ -253,6 +255,7 @@ numpy==1.26.4
253
  # via tensorflow-intel
254
  # via torchvision
255
  # via transformers
 
256
  openai==1.35.13
257
  # via langchain-openai
258
  opencv-python==4.10.0.84
@@ -378,11 +381,6 @@ scipy==1.14.0
378
  # via modelscope
379
  semantic-version==2.10.0
380
  # via gradio
381
- setuptools==69.5.1
382
- # via aitoolkits-webui
383
- # via modelscope
384
- # via tensorboard
385
- # via tensorflow-intel
386
  shellingham==1.5.4
387
  # via typer
388
  simplejson==3.19.2
@@ -431,15 +429,16 @@ tomlkit==0.12.0
431
  # via gradio
432
  toolz==0.12.1
433
  # via altair
434
- torch==2.3.1
435
  # via accelerate
436
  # via aitoolkits-webui
437
  # via kornia
438
  # via torchaudio
439
  # via torchvision
440
- torchaudio==2.3.1
 
441
  # via aitoolkits-webui
442
- torchvision==0.18.1
443
  # via aitoolkits-webui
444
  tqdm==4.66.4
445
  # via datasets
@@ -492,9 +491,16 @@ wheel==0.43.0
492
  # via astunparse
493
  wrapt==1.16.0
494
  # via tensorflow-intel
 
 
495
  xxhash==3.4.1
496
  # via datasets
497
  yarl==1.9.4
498
  # via aiohttp
499
  zipp==3.19.2
500
  # via importlib-metadata
 
 
 
 
 
 
7
  # all-features: false
8
  # with-sources: false
9
  # generate-hashes: false
10
+ # universal: false
11
 
12
  -e file:.
13
  absl-py==2.1.0
 
222
  # via tensorflow-intel
223
  modelscope==1.16.1
224
  # via aitoolkits-webui
225
+ # via modelscope
226
  mpmath==1.3.0
227
  # via sympy
228
  multidict==6.0.5
 
255
  # via tensorflow-intel
256
  # via torchvision
257
  # via transformers
258
+ # via xformers
259
  openai==1.35.13
260
  # via langchain-openai
261
  opencv-python==4.10.0.84
 
381
  # via modelscope
382
  semantic-version==2.10.0
383
  # via gradio
 
 
 
 
 
384
  shellingham==1.5.4
385
  # via typer
386
  simplejson==3.19.2
 
429
  # via gradio
430
  toolz==0.12.1
431
  # via altair
432
+ torch==2.3.1+cu121
433
  # via accelerate
434
  # via aitoolkits-webui
435
  # via kornia
436
  # via torchaudio
437
  # via torchvision
438
+ # via xformers
439
+ torchaudio==2.3.1+cu121
440
  # via aitoolkits-webui
441
+ torchvision==0.18.1+cu121
442
  # via aitoolkits-webui
443
  tqdm==4.66.4
444
  # via datasets
 
491
  # via astunparse
492
  wrapt==1.16.0
493
  # via tensorflow-intel
494
+ xformers==0.0.27
495
+ # via aitoolkits-webui
496
  xxhash==3.4.1
497
  # via datasets
498
  yarl==1.9.4
499
  # via aiohttp
500
  zipp==3.19.2
501
  # via importlib-metadata
502
+ setuptools==69.5.1
503
+ # via aitoolkits-webui
504
+ # via modelscope
505
+ # via tensorboard
506
+ # via tensorflow-intel
requirements.txt CHANGED
@@ -7,6 +7,7 @@
7
  # all-features: false
8
  # with-sources: false
9
  # generate-hashes: false
 
10
 
11
  absl-py==2.1.0
12
  # via keras
@@ -220,6 +221,7 @@ ml-dtypes==0.4.0
220
  # via tensorflow-intel
221
  modelscope==1.16.1
222
  # via aitoolkits-webui
 
223
  mpmath==1.3.0
224
  # via sympy
225
  multidict==6.0.5
@@ -252,6 +254,7 @@ numpy==1.26.4
252
  # via tensorflow-intel
253
  # via torchvision
254
  # via transformers
 
255
  openai==1.35.13
256
  # via langchain-openai
257
  opencv-python==4.10.0.84
@@ -377,11 +380,6 @@ scipy==1.14.0
377
  # via modelscope
378
  semantic-version==2.10.0
379
  # via gradio
380
- setuptools==69.5.1
381
- # via aitoolkits-webui
382
- # via modelscope
383
- # via tensorboard
384
- # via tensorflow-intel
385
  shellingham==1.5.4
386
  # via typer
387
  simplejson==3.19.2
@@ -416,7 +414,6 @@ tensorboard-data-server==0.7.2
416
  # via tensorboard
417
  tensorflow==2.17.0
418
  # via aitoolkits-webui
419
- # via tensorflow
420
  tensorflow-io-gcs-filesystem==0.31.0
421
  # via tensorflow-intel
422
  termcolor==2.4.0
@@ -429,15 +426,16 @@ tomlkit==0.12.0
429
  # via gradio
430
  toolz==0.12.1
431
  # via altair
432
- torch==2.3.1
433
  # via accelerate
434
  # via aitoolkits-webui
435
  # via kornia
436
  # via torchaudio
437
  # via torchvision
438
- torchaudio==2.3.1
 
439
  # via aitoolkits-webui
440
- torchvision==0.18.1
441
  # via aitoolkits-webui
442
  tqdm==4.66.4
443
  # via datasets
@@ -490,9 +488,16 @@ wheel==0.43.0
490
  # via astunparse
491
  wrapt==1.16.0
492
  # via tensorflow-intel
 
 
493
  xxhash==3.4.1
494
  # via datasets
495
  yarl==1.9.4
496
  # via aiohttp
497
  zipp==3.19.2
498
  # via importlib-metadata
 
 
 
 
 
 
7
  # all-features: false
8
  # with-sources: false
9
  # generate-hashes: false
10
+ # universal: false
11
 
12
  absl-py==2.1.0
13
  # via keras
 
221
  # via tensorflow-intel
222
  modelscope==1.16.1
223
  # via aitoolkits-webui
224
+ # via modelscope
225
  mpmath==1.3.0
226
  # via sympy
227
  multidict==6.0.5
 
254
  # via tensorflow-intel
255
  # via torchvision
256
  # via transformers
257
+ # via xformers
258
  openai==1.35.13
259
  # via langchain-openai
260
  opencv-python==4.10.0.84
 
380
  # via modelscope
381
  semantic-version==2.10.0
382
  # via gradio
 
 
 
 
 
383
  shellingham==1.5.4
384
  # via typer
385
  simplejson==3.19.2
 
414
  # via tensorboard
415
  tensorflow==2.17.0
416
  # via aitoolkits-webui
 
417
  tensorflow-io-gcs-filesystem==0.31.0
418
  # via tensorflow-intel
419
  termcolor==2.4.0
 
426
  # via gradio
427
  toolz==0.12.1
428
  # via altair
429
+ torch==2.3.1+cu121
430
  # via accelerate
431
  # via aitoolkits-webui
432
  # via kornia
433
  # via torchaudio
434
  # via torchvision
435
+ # via xformers
436
+ torchaudio==2.3.1+cu121
437
  # via aitoolkits-webui
438
+ torchvision==0.18.1+cu121
439
  # via aitoolkits-webui
440
  tqdm==4.66.4
441
  # via datasets
 
488
  # via astunparse
489
  wrapt==1.16.0
490
  # via tensorflow-intel
491
+ xformers==0.0.27
492
+ # via aitoolkits-webui
493
  xxhash==3.4.1
494
  # via datasets
495
  yarl==1.9.4
496
  # via aiohttp
497
  zipp==3.19.2
498
  # via importlib-metadata
499
+ setuptools==69.5.1
500
+ # via aitoolkits-webui
501
+ # via modelscope
502
+ # via tensorboard
503
+ # via tensorflow-intel