AlekseyCalvin committed on
Commit
5741f84
1 Parent(s): 5b1fc08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -5
app.py CHANGED
@@ -8,15 +8,57 @@ from diffusers import DiffusionPipeline
8
  import copy
9
  import random
10
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Load LoRAs from JSON file
13
  with open('loras.json', 'r') as f:
14
  loras = json.load(f)
15
-
16
- # Initialize the base model
17
- base_model = "sayakpaul/FLUX.1-merged"
18
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
19
-
20
  MAX_SEED = 2**32-1
21
 
22
  class calculateDuration:
 
8
  import copy
9
  import random
10
  import time
11
+ from huggingface_hub import hf_hub_download
12
+ from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
13
+ from accelerate import init_empty_weights
14
+ from convert_nf4_flux import replace_with_bnb_linear, create_quantized_param, check_quantized_param
15
+ from diffusers import FluxTransformer2DModel, FluxPipeline
16
+ import safetensors.torch
17
+ import gc
18
+ import torch
19
+
20
+ # Set dtype and check for float8 support
21
+ dtype = torch.bfloat16
22
+ is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
23
+
24
+ ckpt_path = hf_hub_download("ABDALLALSWAITI/Maxwell", filename="diffusion_pytorch_model.safetensors")
25
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
26
+
27
+ with init_empty_weights():
28
+ config = FluxTransformer2DModel.load_config("ABDALLALSWAITI/Maxwell")
29
+ model = FluxTransformer2DModel.from_config(config).to(dtype)
30
+ expected_state_dict_keys = list(model.state_dict().keys())
31
+
32
+ # Load the state dict into the quantized model
33
+ for param_name, param in original_state_dict.items():
34
+ if param_name not in expected_state_dict_keys:
35
+ continue
36
+
37
+ is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
38
+ if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
39
+ param = param.to(dtype)
40
+
41
+ if not check_quantized_param(model, param_name):
42
+ set_module_tensor_to_device(model, param_name, device=0, value=param)
43
+ else:
44
+ create_quantized_param(
45
+ model, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=True
46
+ )
47
+
48
+ # Clean up
49
+ del original_state_dict
50
+ gc.collect()
51
+
52
+ # Print model size
53
+ print(compute_module_sizes(model)[""] / 1024 / 1024)  # total model size: bytes -> MiB
54
+
55
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
56
+ pipe.enable_model_cpu_offload()
57
 
58
  # Load LoRAs from JSON file
59
  with open('loras.json', 'r') as f:
60
  loras = json.load(f)
61
+
 
 
 
 
62
  MAX_SEED = 2**32-1
63
 
64
  class calculateDuration: