aifeifei798 committed
Commit 17187d2
1 Parent(s): a62e14b

Update extras/expansion.py

Files changed (1)
  1. extras/expansion.py +129 -129
extras/expansion.py CHANGED
@@ -1,129 +1,129 @@
 # Fooocus GPT2 Expansion
 # Algorithm created by Lvmin Zhang at 2023, Stanford
 # If used inside Fooocus, any use is permitted.
 # If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0).
 # This applies to the word list, vocab, model, and algorithm.


 import os
 import torch
 import math
 import ldm_patched.modules.model_management as model_management

 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 # from modules.config import path_fooocus_expansion
 from ldm_patched.modules.model_patcher import ModelPatcher

 path_fooocus_expansion = "extras/fooocus_expansion"
 # limitation of np.random.seed(), called from transformers.set_seed()
 SEED_LIMIT_NUMPY = 2**32
 neg_inf = - 8192.0


 def safe_str(x):
     x = str(x)
     for _ in range(16):
         x = x.replace('  ', ' ')
     return x.strip(",. \r\n")


 def remove_pattern(x, pattern):
     for p in pattern:
         x = x.replace(p, '')
     return x


 class FooocusExpansion:
     def __init__(self):
         self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)

         positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
                               encoding='utf-8').read().splitlines()
         positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']

         self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf

         debug_list = []
         for k, v in self.tokenizer.vocab.items():
             if k in positive_words:
                 self.logits_bias[0, v] = 0
                 debug_list.append(k[1:])

         print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.')

         # debug_list = '\n'.join(sorted(debug_list))
         # print(debug_list)

         # t11 = self.tokenizer(',', return_tensors="np")
         # t198 = self.tokenizer('\n', return_tensors="np")
         # eos = self.tokenizer.eos_token_id

         self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
         self.model.eval()

         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()

         # MPS hack
         if model_management.is_device_mps(load_device):
             load_device = torch.device('cpu')
             offload_device = torch.device('cpu')

         use_fp16 = model_management.should_use_fp16(device=load_device)

         if use_fp16:
             self.model.half()

         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
         print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')

     @torch.no_grad()
     @torch.inference_mode()
     def logits_processor(self, input_ids, scores):
         assert scores.ndim == 2 and scores.shape[0] == 1
         self.logits_bias = self.logits_bias.to(scores)

         bias = self.logits_bias.clone()
         bias[0, input_ids[0].to(bias.device).long()] = neg_inf
         bias[0, 11] = 0

         return scores + bias

     @torch.no_grad()
     @torch.inference_mode()
     def __call__(self, prompt, seed):
         if prompt == '':
             return ''

         if self.patcher.current_device != self.patcher.load_device:
             print('Fooocus Expansion loaded by itself.')
             model_management.load_model_gpu(self.patcher)

         seed = int(seed) % SEED_LIMIT_NUMPY
         set_seed(seed)
         prompt = safe_str(prompt) + ','

         tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
         tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
         tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)

         current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
-        max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
+        max_token_length = 225 * int(math.ceil(float(current_token_length) / 225.0))
         max_new_tokens = max_token_length - current_token_length

         if max_new_tokens == 0:
             return prompt[:-1]

         # https://huggingface.co/blog/introducing-csearch
         # https://huggingface.co/docs/transformers/generation_strategies
         features = self.model.generate(**tokenized_kwargs,
                                        top_k=100,
                                        max_new_tokens=max_new_tokens,
                                        do_sample=True,
                                        logits_processor=LogitsProcessorList([self.logits_processor]))

         response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
         result = safe_str(response[0])

         return result
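The only functional change in this commit is the max_token_length line in __call__: the prompt length is now rounded up to a multiple of 225 tokens instead of 75, giving the GPT-2 expansion roughly three times as much room for appended style words (75 matches the 75-token chunks used by CLIP-style prompt encoding, i.e. 77 positions minus BOS and EOS; 225 is presumably three such chunks). A minimal sketch of the arithmetic, where plan_expansion and the 30-token prompt are made up for illustration:

import math

def plan_expansion(current_token_length, chunk):
    # Hypothetical helper: round the target length up to the next
    # multiple of `chunk` and return how many new tokens to generate.
    max_token_length = chunk * int(math.ceil(float(current_token_length) / chunk))
    return max_token_length - current_token_length

print(plan_expansion(30, 75))   # 45  -> the old code pads a 30-token prompt to 75
print(plan_expansion(30, 225))  # 195 -> the new code pads the same prompt to 225

When the prompt already sits exactly on a chunk boundary, max_new_tokens is 0 and __call__ returns the prompt unchanged (with the temporary trailing comma stripped).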
 
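A note on the whitelist built in __init__: GPT-2's byte-level BPE encodes a leading space as the character 'Ġ', so prefixing each word from positive.txt with 'Ġ' selects exactly the vocabulary entries for those words when they start a new word. Every other vocabulary id keeps a bias of -8192, which effectively removes it from sampling. A toy illustration with a made-up 8-token vocabulary (the real code uses the full GPT-2 vocab and the positive.txt word list):

import torch

neg_inf = -8192.0
vocab_size = 8                     # toy vocab, for illustration only
allowed_ids = [2, 5]               # stand-ins for the whitelisted words

logits_bias = torch.full((1, vocab_size), neg_inf)
logits_bias[0, allowed_ids] = 0.0  # whitelisted ids keep their raw score

scores = torch.zeros(1, vocab_size)
print(scores + logits_bias)        # ids 2 and 5 stay at 0; everything else drops to -8192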
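logits_processor then refreshes that bias at every generation step: each token already present in input_ids is banned, so the expansion never repeats itself, while bias[0, 11] = 0 keeps token id 11, which is ',' in the GPT-2 vocabulary (compare the commented-out t11 line in __init__), available as a separator. The same logic as a standalone sketch, reusing the toy vocabulary above:

import torch

def apply_step_bias(input_ids, scores, logits_bias, comma_id=3, neg_inf=-8192.0):
    # comma_id stands in for GPT-2's real ',' id (11).
    bias = logits_bias.clone()
    bias[0, input_ids[0].long()] = neg_inf  # ban every token seen so far
    bias[0, comma_id] = 0.0                 # the separator always stays allowed
    return scores + bias

logits_bias = torch.full((1, 8), -8192.0)
logits_bias[0, [2, 5]] = 0.0                # same toy whitelist as above
out = apply_step_bias(torch.tensor([[5]]), torch.zeros(1, 8), logits_bias)
print(out[0, 5].item(), out[0, 2].item())   # -8192.0 (banned repeat) vs 0.0 (allowed)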
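End to end, __call__ trims the prompt, appends a comma, and lets GPT-2 fill the gap up to the next 225-token boundary. The seed is reduced modulo 2**32 because transformers.set_seed() also seeds np.random.seed(), which only accepts 32-bit values; with a fixed seed the same prompt always yields the same expansion. A hypothetical usage, assuming the repository root is on sys.path and the fooocus_expansion tokenizer and model files are present under extras/fooocus_expansion:

from extras.expansion import FooocusExpansion

expansion = FooocusExpansion()               # loads tokenizer, word list and GPT-2 weights
print(expansion('a photo of a cat', seed=12345))
# -> 'a photo of a cat, <style words sampled from the whitelist>'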