zR committed on
Commit
e1d35f0
1 Parent(s): 4ddbeae

use cogview3 model

Browse files
Files changed (1) hide show
  1. app.py +6 -18
app.py CHANGED
@@ -5,25 +5,15 @@ import time
5
  from datetime import datetime, timedelta
6
 
7
  import gradio as gr
8
- import numpy as np
9
  import random
10
  import spaces # [uncomment to use ZeroGPU]
11
- from diffusers import FluxPipeline
12
  import torch
13
  from openai import OpenAI
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
- model_repo_id = "black-forest-labs/FLUX.1-dev"
17
 
18
- if torch.cuda.is_available():
19
- torch_dtype = torch.float16
20
- else:
21
- torch_dtype = torch.float32
22
-
23
- pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
24
- pipe = pipe.to(device)
25
-
26
- MAX_SEED = np.iinfo(np.int32).max
27
 
28
 
29
  def clean_string(s):
@@ -134,9 +124,7 @@ threading.Thread(target=delete_old_files, daemon=True).start()
134
  def infer(prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
135
  progress=gr.Progress(track_tqdm=True)):
136
  if randomize_seed:
137
- seed = random.randint(0, MAX_SEED)
138
-
139
- generator = torch.Generator().manual_seed(seed)
140
 
141
  image = pipe(
142
  prompt=prompt,
@@ -145,7 +133,7 @@ def infer(prompt, seed, randomize_seed, width, height, guidance_scale, num_infer
145
  num_inference_steps=num_inference_steps,
146
  width=width,
147
  height=height,
148
- generator=generator
149
  ).images[0]
150
  return image, seed
151
 
@@ -207,7 +195,7 @@ with gr.Blocks(css=css) as demo:
207
  seed = gr.Slider(
208
  label="Seed",
209
  minimum=0,
210
- maximum=MAX_SEED,
211
  step=1,
212
  value=0,
213
  )
@@ -220,7 +208,7 @@ with gr.Blocks(css=css) as demo:
220
  minimum=512,
221
  maximum=2048,
222
  step=32,
223
- value=1024, # 替换为你的模型默认值
224
  )
225
 
226
  height = gr.Slider(
 
5
  from datetime import datetime, timedelta
6
 
7
  import gradio as gr
 
8
  import random
9
  import spaces # [uncomment to use ZeroGPU]
10
+ from diffusers import CogView3PlusPipeline
11
  import torch
12
  from openai import OpenAI
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
 
16
+ pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16).to(device)
 
 
 
 
 
 
 
 
17
 
18
 
19
  def clean_string(s):
 
124
  def infer(prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
125
  progress=gr.Progress(track_tqdm=True)):
126
  if randomize_seed:
127
+ seed = random.randint(0, 65536)
 
 
128
 
129
  image = pipe(
130
  prompt=prompt,
 
133
  num_inference_steps=num_inference_steps,
134
  width=width,
135
  height=height,
136
+ generator=torch.Generator().manual_seed(seed)
137
  ).images[0]
138
  return image, seed
139
 
 
195
  seed = gr.Slider(
196
  label="Seed",
197
  minimum=0,
198
+ maximum=65536,
199
  step=1,
200
  value=0,
201
  )
 
208
  minimum=512,
209
  maximum=2048,
210
  step=32,
211
+ value=1024,
212
  )
213
 
214
  height = gr.Slider(