
Summary

Distilled with the Distily library, using gpt2 as the teacher model and the wikimedia/wikipedia dataset.

Model Architecture:

  • Architecture: GPT2LMHeadModel
  • Total Parameters: 124,439,808
  • Data Type (dtype): torch.bfloat16
  • Model Size: 0.24 GB
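
The parameter count and size above can be reproduced by hand. This sketch assumes the standard GPT-2 small configuration (n_embd=768, n_layer=12, n_positions=1024, vocab_size=50257) with lm_head tied to the token embedding, so the head adds no parameters. These config values come from the public gpt2 checkpoint, not from this card.

```python
# Parameter breakdown, assuming the standard GPT-2 small config (not stated on this card).
n_embd, n_layer, n_positions, vocab_size = 768, 12, 1024, 50257

def gpt2_param_count(d, layers, positions, vocab):
    embeddings = vocab * d + positions * d          # wte (token) + wpe (position)
    layer_norm = 2 * d                              # weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)        # c_attn (q, k, v) + c_proj
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)     # c_fc + c_proj
    per_block = 2 * layer_norm + attn + mlp         # ln_1 + ln_2 + attn + mlp
    return embeddings + layers * per_block + layer_norm  # + final ln_f; lm_head is tied

params = gpt2_param_count(n_embd, n_layer, n_positions, vocab_size)
size_bytes = params * 2  # bfloat16 stores 2 bytes per parameter
print(f"{params:,} parameters, {size_bytes / 1e9:.4f} GB in bf16")
```

This gives 124,439,808 parameters and about 0.249 GB. The card's 0.24 GB agrees with that figure if it was truncated to two decimals. How Distily computes the size is not stated here.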

Benchmark Metrics Comparison

Metric | attn_layer_mapper=all, attn_loss_fn=cos, attn_projector=orthogonal, attn_weight=5 | attn_layer_mapper=layer-2, attn_loss_fn=raw_mse, attn_projector=orthogonal, attn_weight=25.0 | teacher (gpt2)
--- | --- | --- | ---
ai2_arc (acc) | 0.313 | 0.305 | 0.354
ai2_arc (acc_norm) | 0.31 | 0.302 | 0.339
arc_challenge (acc) | 0.181 | 0.173 | 0.188
arc_challenge (acc_norm) | 0.224 | 0.223 | 0.222
arc_easy (acc) | 0.378 | 0.37 | 0.436
arc_easy (acc_norm) | 0.353 | 0.34 | 0.396
boolq (acc) | 0.49 | 0.387 | 0.51
cola (mcc) | -0.041 | 0.044 | 0.01
glue (acc) | 0.396 | 0.412 | 0.403
glue (f1) | 0.516 | 0.451 | 0.529
glue (mcc) | -0.041 | 0.044 | 0.01
hellaswag (acc) | 0.32 | 0.315 | 0.343
hellaswag (acc_norm) | 0.348 | 0.344 | 0.393
mnli (acc) | 0.336 | 0.338 | 0.338
mnli_mismatch (acc) | 0.343 | 0.351 | 0.346
mrpc (acc) | 0.444 | 0.353 | 0.515
mrpc (f1) | 0.478 | 0.143 | 0.631
qnli (acc) | 0.488 | 0.497 | 0.491
qqp (acc) | 0.356 | 0.406 | 0.367
qqp (f1) | 0.522 | 0.501 | 0.512
rte (acc) | 0.56 | 0.549 | 0.516
sst2 (acc) | 0.498 | 0.545 | 0.511
wikitext (bits_per_byte) | 1.118 | 1.127 | 0.98
wikitext (byte_perplexity) | 2.17 | 2.184 | 1.973
wikitext (word_perplexity) | 63.05 | 65.25 | 37.82
wnli (acc) | 0.408 | 0.451 | 0.451
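
The three wikitext metrics are not independent. This assumes the usual definitions used by EleutherAI's lm-evaluation-harness; the card does not name its evaluation tool. Under those definitions, bits_per_byte = log2(byte_perplexity). Both perplexities also share one total negative log-likelihood, so log(word_perplexity) / log(byte_perplexity) equals the average number of bytes per word. The sketch below checks the first identity against the table and estimates that ratio. The inputs are the table's own values, which are rounded to three decimals.

```python
import math

# (byte_perplexity, bits_per_byte, word_perplexity) copied from the wikitext rows above
results = {
    "student (cos, all layers)": (2.17, 1.118, 63.05),
    "student (raw_mse, layer-2)": (2.184, 1.127, 65.25),
    "teacher (gpt2)": (1.973, 0.98, 37.82),
}

def check_run(byte_ppl, word_ppl):
    # bits per byte is the base-2 log of the per-byte perplexity
    derived_bpb = math.log2(byte_ppl)
    # same total NLL in both perplexities, so the log ratio is bytes per word
    bytes_per_word = math.log(word_ppl) / math.log(byte_ppl)
    return derived_bpb, bytes_per_word

for name, (byte_ppl, bpb, word_ppl) in results.items():
    derived_bpb, bpw = check_run(byte_ppl, word_ppl)
    print(f"{name}: bits_per_byte {derived_bpb:.3f} (reported {bpb}), ~{bpw:.2f} bytes/word")
```

All three runs imply roughly 5.35 bytes per word. That is consistent with the runs being scored on the same evaluation text.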

Resource Usage Comparison

  • VRAM Use: 8.2855 GB

Distillation (Teacher -> Student) Architecture Difference:

  • Architecture: GPT2LMHeadModel -> GPT2LMHeadModel
  • Total Parameters: 124,439,808 -> 124,439,808
  • Data Type (dtype): torch.bfloat16 -> torch.bfloat16
  • Model Size: 0.24 GB -> 0.24 GB

Train Dataset

Trained on 145,724,804 tokens from the wikimedia/wikipedia dataset.

  • Num Samples: 247,500
  • Subset: 20231101.en
  • Split: train
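
The sample count follows from dataset_sample_size=250000 and dataset_test_size=0.01 in the hyperparameters below. This assumes the held-out split is rounded up, as in Hugging Face Datasets' train_test_split; the card does not show how Distily performs the split. The average sequence length is derived from the reported token count.

```python
import math
from fractions import Fraction

dataset_sample_size = 250_000
dataset_test_size = Fraction("0.01")  # exact fraction, avoids float rounding
train_tokens = 145_724_804

n_test = math.ceil(dataset_sample_size * dataset_test_size)  # rounded up (assumed)
n_train = dataset_sample_size - n_test
tokens_per_sample = train_tokens / n_train  # mean tokens per training sample

print(n_train, n_test, round(tokens_per_sample, 1))
```

This gives 247,500 training samples, matching the card, with an average of roughly 589 tokens each.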

Training Objective

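The student was trained on a weighted sum of two losses: a KL-divergence loss on the logits (weight 1) and a cosine loss on the attention component, applied to all layers (weight 5):

DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl), attn_loss_component=LossComponent(label=attn, weight=5, loss_fn=cos, layer_mapper=all))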
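
The pure-Python sketch below shows one way the two components could combine. It rests on several assumptions: the logits term is KL(teacher || student) over softmax distributions; the attention term is 1 minus cosine similarity, averaged over layer pairs; and layer_mapper=all pairs student layer i with teacher layer i. The function names are hypothetical. Distily's actual implementation may differ in temperature, reduction, the orthogonal projector, and exactly which attention tensors are compared.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_loss(teacher_logits, student_logits):
    # KL(teacher || student) for one token's distribution (direction assumed)
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cos_loss(teacher_feat, student_feat):
    # 1 - cosine similarity between two flattened attention feature vectors
    dot = sum(a * b for a, b in zip(teacher_feat, student_feat))
    norms = math.sqrt(sum(a * a for a in teacher_feat)) * math.sqrt(sum(b * b for b in student_feat))
    return 1.0 - dot / norms

def distillation_loss(t_logits, s_logits, t_attn, s_attn, logits_weight=1.0, attn_weight=5.0):
    # layer_mapper=all (assumed meaning): student layer i is paired with teacher layer i
    attn_term = sum(cos_loss(t, s) for t, s in zip(t_attn, s_attn)) / len(t_attn)
    return logits_weight * kl_loss(t_logits, s_logits) + attn_weight * attn_term

# Toy example: identical logits, orthogonal attention features in a single layer
print(distillation_loss([1.0, 2.0], [1.0, 2.0], [[1.0, 0.0]], [[0.0, 1.0]]))  # 5.0
```

With identical logits the KL term vanishes. The whole loss then comes from the attention term, scaled by weight 5.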

Hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0001
  • train_batch_size: 4
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine_with_min_lr
  • lr_scheduler_warmup_ratio: 0.5
  • num_epochs: 1.0
  • distillation_objective: DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl), attn_loss_component=LossComponent(label=attn, weight=5, loss_fn=cos, layer_mapper=all))
  • train_embeddings: True
  • lr_scheduler: torch.optim.lr_scheduler.LambdaLR
  • student_model_name_or_path: None
  • student_config_name_or_path: None
  • student_model_config: None
  • reinitialize_weights: None
  • copy_teacher_modules: [('lm_head', False)]
  • student_model_as_bitnet: True
  • student_model_compile: False
  • dropout: None
  • teacher_model_name_or_path: gpt2
  • teacher_load_in_8bit: False
  • teacher_load_in_4bit: False
  • teacher_model_compile: False
  • dataset_uri: wikimedia/wikipedia
  • dataset_subset: 20231101.en
  • dataset_split: train
  • dataset_column_name: text
  • dataset_sample_size: 250000
  • dataset_test_size: 0.01
  • gradient_accumulation_steps: 1
  • weight_decay: 0.0
  • max_grad_norm: 1.0
  • warmup_ratio: 0.5
  • warmup_steps: 0
  • gradient_checkpointing: True
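
With warmup_ratio 0.5, the first half of training is spent in warmup. The sketch below assumes a single device, so one epoch of 247,500 samples at train_batch_size 4 with no gradient accumulation gives 61,875 optimizer steps. It mirrors the shape of a linear-warmup, cosine-decay-to-minimum schedule such as the cosine_with_min_lr type in Hugging Face Transformers. Two details are placeholders: the minimum learning rate is not reported on this card, so min_lr_rate=0.1 is hypothetical, and the helper lr_at_step is a hypothetical name.

```python
import math

def lr_at_step(step, total_steps, base_lr=1e-4, warmup_ratio=0.5, min_lr_rate=0.1):
    # min_lr_rate is a hypothetical placeholder: the card does not report the minimum LR
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # linear warmup from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # cosine decay from base_lr down to base_lr * min_lr_rate
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (cosine * (1.0 - min_lr_rate) + min_lr_rate)

# 247,500 samples, train_batch_size=4, gradient_accumulation_steps=1, one device (assumed)
total_steps = 247_500 // (4 * 1)

for step in (0, total_steps // 4, math.ceil(total_steps * 0.5), total_steps):
    print(step, f"{lr_at_step(step, total_steps):.2e}")
```

The learning rate peaks at 1e-4 only at the midpoint of the single epoch. The second half of the run is all cosine decay.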

Framework Versions

  • Distily 0.3.0
  • Transformers 4.44.0
  • Pytorch 2.3.0
  • Datasets 2.21.0

Model ID: distily/distily_multi_attn_experiment_ortho