FA2
- requirements.txt +1 -0
- src/calibration_datasets.py +8 -9
- src/medusa_training_script.py +23 -10
requirements.txt
CHANGED
@@ -1 +1,2 @@
 medusa-llm[train]
+flash-attn
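The flash-attn 2.x kernels only run on NVIDIA GPUs with compute capability 8.0 or higher (Ampere and newer), so adding it to requirements.txt does not guarantee it is usable at runtime. A minimal sketch of a hardware probe; gpu_supports_fa2 is a hypothetical helper, not part of this repo:

import torch

def gpu_supports_fa2() -> bool:
    # flash-attn 2.x requires a CUDA device with compute
    # capability >= 8.0 (Ampere or newer).
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8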
src/calibration_datasets.py
CHANGED
@@ -15,7 +15,6 @@ class CalibrationDataset(ABC):
     dataset_config: dict
     dataset: str
     dataset_name: str
-    dataset_limit: int = int(1e7)
 
     # Defines the field to extract from the HF dataset
     # If specified, just this field will be returned, and no transformation will be done.
@@ -125,7 +124,7 @@ class CalibrationDataset(ABC):
 
         print(f"Loading HF dataset {path} with params: {kwargs}")
         data: Dataset = load_dataset(path=path, streaming=True, **kwargs)
-        return data.shuffle().take(limit)
+        return iter(data.shuffle().take(limit))
 
     @staticmethod
     def list_with_nls(samples: List[str]) -> List[str]:
@@ -152,11 +151,11 @@ class CalibrationDataset(ABC):
         """
         # Load HF dataset. Subclasses provide HF dataset details in `dataset_config`
         if not self.data:
-            self.data = self.get_hf_dataset(**self.dataset_config, limit=self.dataset_limit)
+            self.data = self.get_hf_dataset(**self.dataset_config, limit=self.num_samples*10)
 
         if not self.samples:
             if hasattr(self, "dataset_field") and self.dataset_field:
-                samples =
+                samples = [data[self.dataset_field] for data in self.data]
             else:
                 try:
                     samples = self.process_samples()
@@ -222,11 +221,11 @@ class WikitextDataset(CalibrationDataset):
     }
     dataset_name = "Wikitext103 Full"
 
-
-
-
-
-
+    def process_samples(self) -> List[str]:
+        return [
+            "\n" if len(item) == 0 else item
+            for item in self.data["text"]
+        ]
 
 
 class C4Dataset(CalibrationDataset):
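The first two hunks change how the streamed dataset is consumed: get_hf_dataset now returns iter(data.shuffle().take(limit)), and the fixed dataset_limit gives way to limit=self.num_samples*10, presumably to leave headroom for filtering while still avoiding a full download. A short sketch of the underlying datasets behavior; the wikitext path and parameters here are chosen purely for illustration:

from datasets import load_dataset

# streaming=True yields an IterableDataset: shuffle() mixes within a
# buffer, and take() caps how many examples are pulled from the stream.
data = load_dataset(
    "wikitext", "wikitext-103-raw-v1", split="train", streaming=True
)
subset = data.shuffle(seed=0, buffer_size=1000).take(50)

# iter() turns the IterableDataset into a plain, single-use iterator,
# which is what the field-extraction list comprehension consumes.
texts = [row["text"] for row in iter(subset)]
print(len(texts))  # 50

Note that the resulting iterator is exhausted after one pass, so callers that need the samples twice must materialize them into a list first, as the new `samples = [...]` comprehension does.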
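The new WikitextDataset.process_samples maps empty rows to a bare newline so downstream tokenization never sees an empty string. A hypothetical standalone rendering of the same list comprehension, just to show the effect:

def normalize_wikitext(texts):
    # Wikitext contains many blank rows; keep them as "\n" rather
    # than dropping them or passing empty strings to the tokenizer.
    return ["\n" if len(item) == 0 else item for item in texts]

print(normalize_wikitext(["= Heading =", "", "Body text."]))
# ['= Heading =', '\n', 'Body text.']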
src/medusa_training_script.py
CHANGED
@@ -192,16 +192,29 @@ def train():
     )
 
     # Load model and tokenizer
-    model = transformers.AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
-        config=config,
-        cache_dir=training_args.cache_dir,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.bfloat16,
-        quantization_config=quantization_config if model_args.load_in_4bit else None,
-        load_in_4bit=model_args.load_in_4bit,
-        load_in_8bit=model_args.load_in_8bit,
-    )
+    try:  # Try loading with FA2
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            config=config,
+            cache_dir=training_args.cache_dir,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+            quantization_config=quantization_config if model_args.load_in_4bit else None,
+            load_in_4bit=model_args.load_in_4bit,
+            load_in_8bit=model_args.load_in_8bit,
+            attn_implementation="flash_attention_2",
+        )
+    except:
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            config=config,
+            cache_dir=training_args.cache_dir,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+            quantization_config=quantization_config if model_args.load_in_4bit else None,
+            load_in_4bit=model_args.load_in_4bit,
+            load_in_8bit=model_args.load_in_8bit,
+        )
 
     # Freeze the base model
     for param in model.base_model.parameters():
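This hunk attempts the load with attn_implementation="flash_attention_2" and retries with the default attention on any failure. A bare except also swallows unrelated errors (out-of-memory, a bad checkpoint path), so one alternative is to decide up front whether FA2 is usable. A minimal sketch under that assumption; load_model is a hypothetical helper, not the script's code, and it relies on transformers being recent enough to ship is_flash_attn_2_available and the attn_implementation kwarg:

import torch
import transformers
from transformers.utils import is_flash_attn_2_available

def load_model(name: str, **kwargs) -> transformers.PreTrainedModel:
    # Pick the attention implementation before loading, so errors
    # unrelated to FA2 still surface instead of being retried.
    if torch.cuda.is_available() and is_flash_attn_2_available():
        kwargs["attn_implementation"] = "flash_attention_2"
    return transformers.AutoModelForCausalLM.from_pretrained(name, **kwargs)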