"""Converts a Whisper model in Hugging Face format to OpenAI format. |
|
This script is based on the following script to do the opposite: |
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/convert_openai_to_hf.py |
|
Requirements: |
|
```bash |
|
pip install -U openai-whisper |
|
``` |
|
Example: |
|
```bash |
|
# Converts the model from Hugging Face to OpenAI format: |
|
python convert_hf_to_openai.py \ |
|
--checkpoint openai/whisper-tiny \ |
|
--whisper_dump_path whisper-tiny-openai.pt |
|
``` |
|
```python |
|
>>> # Disabled doctest because it requries the openai-whisper package. |
|
>> import whisper |
|
>> from transformers.models.whisper.convert_hf_to_openai import convert_tfms_to_openai_whisper |
|
>> # Converts the model from Hugging Face to OpenAI format: |
|
>> convert_tfms_to_openai_whisper( |
|
.. "openai/whisper-tiny", "whisper-tiny-openai.pt" |
|
.. ) |
|
HF model path: openai/whisper-tiny |
|
OpenAI model path: whisper-tiny-openai.pt |
|
>> # Select an audio file: |
|
>> audio_path = "https://huggingface.co/datasets/sanchit-gandhi/librispeech_long/resolve/main/audio.wav" |
|
>> # Load the Whisper model in OpenAI format: |
|
>> model = whisper.load_model("whisper-tiny-openai.pt") |
|
>> # Transcribe the audio: |
|
>> prediction = model.transcribe(audio_path) |
|
>> prediction["text"][:70] |
|
' chapter 16. I might have told you of the beginning of this liaison in' |
|
``` |
|
""" |
|
import argparse

import torch
from torch import nn

from transformers import WhisperConfig, WhisperForConditionalGeneration
|
|
# Mapping of parameter-name fragments from the Hugging Face format (keys) to the
# OpenAI Whisper format (values), applied as substring replacements on state_dict keys.
REVERSE_WHISPER_MAPPING = {
    "layers": "blocks",
    "fc1": "mlp.0",
    "fc2": "mlp.2",
    "final_layer_norm": "mlp_ln",
    ".self_attn.q_proj": ".attn.query",
    ".self_attn.k_proj": ".attn.key",
    ".self_attn.v_proj": ".attn.value",
    ".self_attn_layer_norm": ".attn_ln",
    ".self_attn.out_proj": ".attn.out",
    ".encoder_attn.q_proj": ".cross_attn.query",
    ".encoder_attn.k_proj": ".cross_attn.key",
    ".encoder_attn.v_proj": ".cross_attn.value",
    ".encoder_attn_layer_norm": ".cross_attn_ln",
    ".encoder_attn.out_proj": ".cross_attn.out",
    "decoder.layer_norm.": "decoder.ln.",
    "encoder.layer_norm.": "encoder.ln_post.",
    "embed_tokens": "token_embedding",
    "encoder.embed_positions.weight": "encoder.positional_embedding",
    "decoder.embed_positions.weight": "decoder.positional_embedding",
}
|
|
def reverse_rename_keys(s_dict: dict) -> dict:
    """Renames the keys back from Hugging Face to OpenAI Whisper format.

    By using this function on an HF model's state_dict, we should get the names in the format expected by Whisper.

    Args:
        s_dict (`dict`): A dictionary with keys in Hugging Face format.

    Returns:
        `dict`: The same dictionary but in OpenAI Whisper format.
|
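    Example (the input key below is illustrative of Hugging Face Whisper naming, with a placeholder
    value standing in for the weight tensor):

    ```python
    >>> reverse_rename_keys({"model.decoder.layers.0.fc1.weight": 0})
    {'model.decoder.blocks.0.mlp.0.weight': 0}
    ```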
""" |
|
keys = list(s_dict.keys()) |
|
for orig_key in keys: |
|
new_key = orig_key |
|
for key_r, value_r in REVERSE_WHISPER_MAPPING.items(): |
|
if key_r in orig_key: |
|
new_key = new_key.replace(key_r, value_r) |
|
|
|
|
|
|
|
s_dict[new_key] = s_dict.pop(orig_key) |
|
return s_dict |
|
|
def make_emb_from_linear(linear: nn.Linear) -> nn.Embedding:
    """Converts a linear layer's weights into an embedding layer.

    The linear layer's `out_features` dimension corresponds to the vocabulary size and its `in_features`
    dimension corresponds to the embedding size.

    Args:
        linear (`nn.Linear`): The linear layer to be converted.

    Returns:
        `nn.Embedding`:
            An embedding layer with weights set to those of the input linear layer.
|
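    Example (a minimal sketch with arbitrary dimensions; `nn.Linear` stores its weight with shape
    `(out_features, in_features)`):

    ```python
    >>> from torch import nn
    >>> make_emb_from_linear(nn.Linear(in_features=2, out_features=5, bias=False))
    Embedding(5, 2)
    ```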
""" |
|
vocab_size, emb_size = linear.weight.data.shape |
|
emb_layer = nn.Embedding(vocab_size, emb_size, _weight=linear.weight.data) |
|
return emb_layer |
|
|
def extract_dims_from_hf(config: WhisperConfig) -> dict:
    """Extracts the necessary dimensions from Hugging Face's WhisperConfig.

    Extracts the necessary dimensions and related configuration data from the Hugging Face model and
    restructures them for the OpenAI Whisper format.

    Args:
        config (`WhisperConfig`): Configuration of the Hugging Face model.

    Returns:
        `dict`: The `dims` of the OpenAI Whisper model.
|
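    Example (disabled doctest because it downloads a config; the expected values are those of
    `openai/whisper-tiny`):

    ```python
    >> config = WhisperConfig.from_pretrained("openai/whisper-tiny")
    >> dims = extract_dims_from_hf(config)
    >> dims["n_audio_state"], dims["n_audio_head"], dims["n_audio_layer"]
    (384, 6, 4)
    ```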
""" |
|
dims = { |
|
"n_vocab": config.vocab_size, |
|
"n_mels": config.num_mel_bins, |
|
"n_audio_state": config.d_model, |
|
"n_text_ctx": config.max_target_positions, |
|
"n_audio_layer": config.encoder_layers, |
|
"n_audio_head": config.encoder_attention_heads, |
|
"n_text_layer": config.decoder_layers, |
|
"n_text_head": config.decoder_attention_heads, |
|
"n_text_state": config.d_model, |
|
"n_audio_ctx": config.max_source_positions, |
|
} |
|
return dims |
|
|
def convert_tfms_to_openai_whisper(hf_model_path: str, whisper_dump_path: str):
    """Converts a Whisper model from the Hugging Face to the OpenAI format.

    Takes in the path to a Hugging Face Whisper model, extracts its state_dict, renames the keys as needed,
    and then saves the model in OpenAI's format.

    Args:
        hf_model_path (`str`):
            Path to the pretrained Whisper model in Hugging Face format.
        whisper_dump_path (`str`):
            Destination path where the converted model in Whisper/OpenAI format will be saved.

    Returns:
        `None`
|
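    Example (disabled doctest because it downloads a checkpoint; mirrors the module-level example):

    ```python
    >> convert_tfms_to_openai_whisper(
    ..     "openai/whisper-tiny", "whisper-tiny-openai.pt"
    .. )
    HF model path: openai/whisper-tiny
    OpenAI model path: whisper-tiny-openai.pt
    ```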
""" |
|
print("HF model path:", hf_model_path) |
|
print("OpenAI model path:", whisper_dump_path) |
|
|
|
|
|
model = WhisperForConditionalGeneration.from_pretrained(hf_model_path) |
|
state_dict = model.state_dict() |
|
|
|
|
|
state_dict = reverse_rename_keys(state_dict) |
|
|
|
|
|
dims = extract_dims_from_hf(model.config) |
|
|
|
|
|
del state_dict["proj_out.weight"] |
|
|
|
|
|
state_dict = {k.replace("model.", "", 1): v for k, v in state_dict.items()} |
|
whisper_checkpoint = {"dims": dims, "model_state_dict": state_dict} |
|
|
|
|
|
torch.save(whisper_checkpoint, whisper_dump_path) |
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Path or name of the Hugging Face checkpoint.",
    )
    parser.add_argument(
        "--whisper_dump_path",
        type=str,
        help="Path to the output Whisper model.",
    )
    args = parser.parse_args()

    convert_tfms_to_openai_whisper(args.checkpoint, args.whisper_dump_path)