|
---
library_name: transformers
tags: []
---
|
|
|
# Model Card for marsggbo/wmt-switch-base-32-lora-ckpt140000
|
|
|
|
|
|
LoRA adapter weights obtained by fine-tuning google/switch-base-32 on the WMT16 dataset.
|
|
|
# To use the LoRA weights
|
|
|
```python
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the base Switch-base-32 model (PyTorch .bin weights, hence use_safetensors=False)
base_model = AutoModelForSeq2SeqLM.from_pretrained('google/switch-base-32', use_safetensors=False)

# Attach the LoRA adapter from this repository on top of the base model
lora_model = PeftModel.from_pretrained(base_model, 'marsggbo/wmt-switch-base-32-lora-ckpt140000')

# Fold the LoRA deltas into the base weights and drop the adapter wrappers
merged_model = lora_model.merge_and_unload()

# Save the merged model so it can be loaded later without peft
merged_model.save_pretrained('./switch-base-32-wmt', state_dict=merged_model.state_dict(), safe_serialization=False)
```
|
|
|
# To use the merged model weights
|
|
|
```python
from transformers import AutoModelForSeq2SeqLM

# Load the merged checkpoint directly; peft is no longer required
model = AutoModelForSeq2SeqLM.from_pretrained('./switch-base-32-wmt')
```
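


# To run inference with the merged model

Switch-base-32 is a T5-style seq2seq model, so translation runs through `generate`. Below is a minimal sketch, assuming an English-to-German WMT16 pair and the conventional T5 task prefix; the actual language pair and prompt format depend on how the adapter was fine-tuned.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Tokenizer comes from the base model; the merged weights live in ./switch-base-32-wmt
tokenizer = AutoTokenizer.from_pretrained('google/switch-base-32')
model = AutoModelForSeq2SeqLM.from_pretrained('./switch-base-32-wmt')

# Hypothetical prompt: T5-style task prefix for English-to-German translation
inputs = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```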