{ "cells": [ { "cell_type": "code", "execution_count": 9, "id": "bb4dd66b-0c17-48d4-9d34-f48cece2feb5", "metadata": {}, "outputs": [], "source": [ "# !pip install soundfile\n", "# !pip install librosa" ] }, { "cell_type": "code", "execution_count": 1, "id": "6e9386ea-4862-4f5b-a02f-d656e1a5ab9e", "metadata": {}, "outputs": [], "source": [ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n", "from datasets import load_dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "914ab2b4-389d-4c48-8d1d-1250356646ac", "metadata": {}, "outputs": [], "source": [ "# load model and processor\n", "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n", "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n", "model.config.forced_decoder_ids = None\n", "\n", "# load dummy dataset and read audio files\n", "ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n", "sample = ds[0][\"audio\"]" ] }, { "cell_type": "code", "execution_count": 3, "id": "2b299bab-1228-48d9-a8a5-3d5b6c52162d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'path': '/home/gunak/.cache/huggingface/datasets/downloads/extracted/431c2c946d216530b2666a0e7ffa5ac3f5b3da89dd28858a9de6c78fae7caa4a/dev_clean/1272/128104/1272-128104-0000.flac',\n", " 'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,\n", " 0.0010376 ]),\n", " 'sampling_rate': 16000}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample" ] }, { "cell_type": "code", "execution_count": 4, "id": "b7e570a1-cf5c-450c-a7b6-49b45a10d2df", "metadata": {}, "outputs": [], "source": [ "input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features " ] }, { "cell_type": "code", "execution_count": 5, "id": "584e920b-a7fd-402d-95dd-3b9128cd34bb", "metadata": {}, "outputs": [], "source": [ "# generate token ids\n", "predicted_ids = model.generate(input_features)\n", "# decode token ids to text\n", "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n", "\n", "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)" ] }, { "cell_type": "code", "execution_count": 6, "id": "b27ab660-861b-49d1-81f9-f51cb7f9d8d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "transcription" ] }, { "cell_type": "code", "execution_count": 3, "id": "eca553b8-68f6-493d-b567-3d526b49ae1b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn" ] }, { "cell_type": "code", "execution_count": 4, "id": "c619a4cf-9068-4e4d-8139-e16d15345f4f", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer" ] }, { "cell_type": "code", "execution_count": 5, "id": "47d5b1ff-ab0f-4d11-af64-d2fa2be39286", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "model_name = \"microsoft/phi-2\"\n", "phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", "phi2_tokenizer.pad_token = phi2_tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 6, "id": "0b36b3f0-db5b-4029-9072-0a53bcab315a", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'transcription' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tokens \u001b[38;5;241m=\u001b[39m phi2_tokenizer(\u001b[38;5;241m*\u001b[39m\u001b[43mtranscription\u001b[49m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, return_attention_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", "\u001b[0;31mNameError\u001b[0m: name 'transcription' is not defined" ] } ], "source": [ "tokens = phi2_tokenizer(*transcription, return_tensors=\"pt\", return_attention_mask=False)" ] }, { "cell_type": "code", "execution_count": 22, "id": "91f6d3d3-bb00-434f-a91e-6952375890d0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262,\n", " 3504, 6097, 290, 356, 389, 9675, 284, 7062, 465, 21443,\n", " 13]])}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens" ] }, { "cell_type": "code", "execution_count": 12, "id": "533191d9-4b3b-417a-918d-6fe854f24b50", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n", "- configuration_phi.py\n", ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2a65a119388b4cb4b123b532176e786e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "modeling_phi.py: 0%| | 0.00/62.7k [00:00 {text} \"\n", " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n", " return tokens\n", " \n", "\n", "class WhisperWithProjection:\n", " def __init__(self):\n", " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n", " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n", " self.model.config.forced_decoder_ids = None\n", " self.audio_language_connector = AudioLanguageConnector()\n", " \n", " def forward(self, audio):\n", " input_features = self.processor(audio[\"array\"],\n", " sampling_rate=audio[\"sampling_rate\"],\n", " return_tensors=\"pt\").input_features\n", " # generate token ids\n", " predicted_ids = self.model.generate(input_features)\n", " # decode token ids to text \n", " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n", "\n", " audio_embeddings = self.audio_language_connector(transcription)\n", " return audio_embeddings" ] }, { "cell_type": "code", "execution_count": 8, "id": "2b1f8f44-bfe6-413c-9e32-c38fa5517981", "metadata": {}, "outputs": [], "source": [ "class TextModality:\n", " def __init__(self):\n", " model_name = \"microsoft/phi-2\"\n", " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n", " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n", "\n", " def __call__(self, text):\n", " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n", " return tokens" ] }, { "cell_type": "code", "execution_count": 15, "id": "21c51648-abb6-4bbd-b4c1-509967a69337", "metadata": {}, "outputs": [], "source": [ "class MultiModalPhi2:\n", " def __init__(self):\n", " self.text_modality = TextModality()\n", " self.whisper_w_proj = WhisperWithProjection()\n", " self.llm = self.load_llm()\n", "\n", " def load_llm(self):\n", " bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=torch.float16)\n", " \n", " model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " quantization_config=bnb_config,\n", " trust_remote_code=True,\n", " device_map=\"cuda:0\"\n", " )\n", " model.config.use_cache = False\n", " return model\n", "\n", " def generate(self, audio, text):\n", " text_embeddings = self.text_modality(text)\n", " audio_embeddings = self.whisper_w_proj.forward(audio)\n", " inputs = torch.concat([text_embeddings[\"input_ids\"], audio_embeddings[\"input_ids\"]], dim=1)\n", " \n", " # outputs = self.llm.generate(inputs, max_length=200)\n", " outputs = self.llm(inputs)\n", " return outputs\n", " \n", " # text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n", " # print(text)" ] }, { "cell_type": "code", "execution_count": 16, "id": "472a00cb-bae9-4c09-a0ef-bc57881b5e2c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2236e6b1e26d444fa3d48181ba1a6cf9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00)}, logits=tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n", " [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n", " [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n", " ...,\n", " [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n", " [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n", " [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n", " grad_fn=), past_key_values=None, hidden_states=None, attentions=None)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "audio = sample\n", "text = \"explain about the audio\"\n", "multi_modal_phi.generate(audio, text)" ] }, { "cell_type": "code", "execution_count": null, "id": "46aa9c66-a5bb-4760-8895-92673f49345f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }