# Install the fine-tuning stack into the kernel's own environment.
# Versions are pinned to the ones this notebook was last run with (taken from the
# recorded install log) so a fresh runtime resolves the same stack; -q keeps the
# resolver log out of the saved notebook.
%pip install -q transformers==4.35.2 peft==0.7.0 accelerate==0.25.0 optimum==1.15.0 bitsandbytes==0.41.3.post1 trl==0.7.4 wandb==0.16.1
eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting bitsandbytes\n"," Downloading bitsandbytes-0.41.3.post1-py3-none-any.whl (92.6 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting trl\n"," Downloading trl-0.7.4-py3-none-any.whl (133 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.9/133.9 kB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting wandb\n"," Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m101.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.4)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n","Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.1)\n","Requirement already satisfied: tqdm>=4.27 in 
/usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n","Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n","Collecting coloredlogs (from optimum)\n"," Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum) (1.12)\n","Collecting datasets (from optimum)\n"," Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m57.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting tyro>=0.5.11 (from trl)\n"," Downloading tyro-0.6.0-py3-none-any.whl (100 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.9/100.9 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n","Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n"," Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.6/190.6 kB\u001b[0m \u001b[31m30.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting sentry-sdk>=1.0.0 (from wandb)\n"," Downloading sentry_sdk-1.38.0-py2.py3-none-any.whl (252 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n"," Downloading 
docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n","Collecting setproctitle (from wandb)\n"," Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n","Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n","Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n","Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n","Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n"," Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) 
(3.2.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.2)\n","Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n","Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers)\n"," Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m60.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)\n"," Downloading docstring_parser-0.15-py3-none-any.whl (36 kB)\n","Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n","Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n"," Downloading shtab-1.6.5-py3-none-any.whl (13 kB)\n","Collecting humanfriendly>=9.1 (from coloredlogs->optimum)\n"," Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (9.0.0)\n","Collecting pyarrow-hotfix (from datasets->optimum)\n"," Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n","Collecting dill<0.3.8,>=0.3.0 (from datasets->optimum)\n"," Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (1.5.3)\n","Requirement already satisfied: xxhash in 
/usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.4.1)\n","Collecting multiprocess (from datasets->optimum)\n"," Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.9.1)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum) (1.3.0)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (23.1.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (6.0.4)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.9.3)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.4.0)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.3.1)\n","Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (4.0.3)\n","Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n"," Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n","Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from 
jinja2->torch>=1.13.0->peft) (2.1.3)\n","Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2023.3.post1)\n","Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n","Installing collected packages: sentencepiece, bitsandbytes, smmap, shtab, setproctitle, sentry-sdk, pyarrow-hotfix, humanfriendly, docstring-parser, docker-pycreds, dill, multiprocess, gitdb, coloredlogs, tyro, GitPython, accelerate, wandb, datasets, trl, peft, optimum\n","Successfully installed GitPython-3.1.40 accelerate-0.25.0 bitsandbytes-0.41.3.post1 coloredlogs-15.0.1 datasets-2.15.0 dill-0.3.7 docker-pycreds-0.4.0 docstring-parser-0.15 gitdb-4.0.11 humanfriendly-10.0 multiprocess-0.70.15 optimum-1.15.0 peft-0.7.0 pyarrow-hotfix-0.6 sentencepiece-0.1.99 sentry-sdk-1.38.0 setproctitle-1.3.3 shtab-1.6.5 smmap-5.0.1 trl-0.7.4 tyro-0.6.0 wandb-0.16.1\n"]}]},{"cell_type":"code","execution_count":2,"id":"51eb00d7-2928-41ad-9ae9-7f0da7d64d6d","metadata":{"id":"51eb00d7-2928-41ad-9ae9-7f0da7d64d6d","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1702298115386,"user_tz":-540,"elapsed":19506,"user":{"displayName":"조수연","userId":"02851953064562485279"}},"outputId":"287a72a9-e173-419f-d057-3b06d7c1dce4"},"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n"," warnings.warn(\n"]}],"source":["import os\n","from dataclasses import dataclass, field\n","from typing import Optional\n","import re\n","\n","import torch\n","import tyro\n","from accelerate import 
Accelerator\n","from datasets import load_dataset, Dataset\n","from peft import AutoPeftModelForCausalLM, LoraConfig\n","from tqdm import tqdm\n","from transformers import (\n"," AutoModelForCausalLM,\n"," AutoTokenizer,\n"," BitsAndBytesConfig,\n"," TrainingArguments,\n",")\n","\n","from trl import SFTTrainer\n","\n","from trl.trainer import ConstantLengthDataset"]},{"cell_type":"code","source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"],"metadata":{"id":"tX7gYxZaVhYL","colab":{"base_uri":"https://localhost:8080/","height":145,"referenced_widgets":["22e8a9d9cda84d92b1813f8929f4eea0","657a33fb2b544f85930194969feed3c4","8e305c09fb884c35b4eff11a101541ea","f23410ca73f04660ac6428a3eb2f0487","6c4553801fb54d9495d5ee3636de9ce4","1565660880f4460993bceaf233525be8","36948d25a53d4084a82c91fb37e1c1b8","e47ec0f7ec7d45708dfcf38ba066c84e","8b1aa9a9d061465f932b30ac87495e3b","f76f0c3de150487cb028fac16e3638cb","1a24b471df0c4e0b82f0394d4bbda2d8","9e29fc343f1143158f8cd0a42394f133","312b7ba4d0d34e8c98f5737b35fc0e75","feed29fb48db4887b5e39a3662efa4eb","84ead8c6a8f34b9a98f6dc7e48544bf7","5ebfe7fb5c914c36b47e501bccb8c4d5","ef576e415f4045c6a66bc5056f169b33","4f4a174309c94b1faafe6f6bbff4f08e","81d2e99eeece47c2a0f27d915b55afee","7d2e6f866f074889901fd7093e0bb81a","61f2d328438c4df987de5376b9e48bb4","86e6cb6b34a34a19846b6452a1cf3d11","04a8cbf69f7b402188a76dd560c38b44","2b1cbbfb38954d1ab819b564e5db29c6","a54ab155dddf45fdaaeefdf1d574613d","8d583097f2ec413fa25de9c5f1ecfcfd","845eb52754a549238112bae5c0c1b751","e787c4ef3a3e4306876093df43d5fa8d","7ef0269e33204383bf4213ac14c048a2","6e052d1117ae48e6b469b0051d5a8e13","4d84f767af3d45af94b72234127f4788","c61999ab9f194d4a89ec849308f4eaa4"]},"executionInfo":{"status":"ok","timestamp":1702298118857,"user_tz":-540,"elapsed":347,"user":{"displayName":"조수연","userId":"02851953064562485279"}},"outputId":"74676d5a-7075-4a3e-d19a-33e2df450021"},"id":"tX7gYxZaVhYL","execution_count":3,"outputs":[{"output_type":"display_data","data":{"text
/plain":["VBox(children=(HTML(value='
Step | \n","Training Loss | \n","
---|---|
100 | \n","0.268500 | \n","
200 | \n","0.274100 | \n","
300 | \n","0.278600 | \n","
400 | \n","0.270100 | \n","
500 | \n","0.271700 | \n","
600 | \n","0.268900 | \n","
700 | \n","0.263000 | \n","
"],"text/plain":[" "]},"metadata":{}},{"output_type":"execute_result","data":{"text/plain":["TrainOutput(global_step=852, training_loss=0.26882281997394114, metrics={'train_runtime': 3234.9061, 'train_samples_per_second': 0.989, 'train_steps_per_second': 0.495, 'total_flos': 3.469201579494605e+16, 'train_loss': 0.26882281997394114, 'epoch': 0.53})"]},"metadata":{},"execution_count":27}],"source":["trainer.train()"]},{"cell_type":"code","source":["script_args.training_args.output_dir"],"metadata":{"id":"3Y4FQSyRghQt","colab":{"base_uri":"https://localhost:8080/","height":35},"executionInfo":{"status":"ok","timestamp":1702305413065,"user_tz":-540,"elapsed":574,"user":{"displayName":"조수연","userId":"02851953064562485279"}},"outputId":"aea190d1-e8b6-4cd0-c020-11b748fdfd8f"},"id":"3Y4FQSyRghQt","execution_count":28,"outputs":[{"output_type":"execute_result","data":{"text/plain":["'/gdrive/MyDrive/lora-llama-2-7b-food-order-understanding'"],"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"}},"metadata":{},"execution_count":28}]},{"cell_type":"code","execution_count":29,"id":"49f05450-da2a-4edd-9db2-63836a0ec73a","metadata":{"id":"49f05450-da2a-4edd-9db2-63836a0ec73a","executionInfo":{"status":"ok","timestamp":1702305461462,"user_tz":-540,"elapsed":476,"user":{"displayName":"조수연","userId":"02851953064562485279"}}},"outputs":[],"source":["trainer.save_model(script_args.training_args.output_dir)"]},{"cell_type":"markdown","id":"652f307e-e1d7-43ae-b083-dba2d94c2296","metadata":{"id":"652f307e-e1d7-43ae-b083-dba2d94c2296"},"source":["# 추론 테스트"]},{"cell_type":"code","execution_count":30,"id":"ea8a1fea-7499-4386-9dea-0509110f61af","metadata":{"id":"ea8a1fea-7499-4386-9dea-0509110f61af","executionInfo":{"status":"ok","timestamp":1702305464590,"user_tz":-540,"elapsed":1618,"user":{"displayName":"조수연","userId":"02851953064562485279"}}},"outputs":[],"source":["from transformers import pipeline, 
TextStreamer"]},{"cell_type":"code","execution_count":31,"id":"52626888-1f6e-46b6-a8dd-836622149ff5","metadata":{"id":"52626888-1f6e-46b6-a8dd-836622149ff5","executionInfo":{"status":"ok","timestamp":1702305466602,"user_tz":-540,"elapsed":460,"user":{"displayName":"조수연","userId":"02851953064562485279"}}},"outputs":[],"source":["instruction_prompt_template = \"\"\"###System;다음은 매장에서 고객이 음식을 주문하는 주문 문장이다. 이를 분석하여 음식명, 옵션명, 수량을 추출하여 고객의 의도를 이해하고자 한다.\n","분석 결과를 완성해주기 바란다.\n","\n","### 주문 문장: {0} ### 분석 결과:\n","\"\"\"\n","\n","prompt_template = \"\"\"###System;{System}\n","###User;{User}\n","###Midm;\"\"\"\n","\n","default_system_msg = (\n"," \"너는 먼저 사용자가 입력한 주문 문장을 분석하는 에이전트이다. 이로부터 주문을 구성하는 음식명, 옵션명, 수량을 차례대로 추출해야 한다.\"\n",")"]},{"cell_type":"code","execution_count":32,"id":"46e844fa-8f63-4359-a4fb-df66e8171796","metadata":{"id":"46e844fa-8f63-4359-a4fb-df66e8171796","executionInfo":{"status":"ok","timestamp":1702305469440,"user_tz":-540,"elapsed":316,"user":{"displayName":"조수연","userId":"02851953064562485279"}}},"outputs":[],"source":["evaluation_queries = [\n"," \"오늘은 비가오니깐 이거 먹자. 삼선짬뽕 곱배기 하나하구요, 사천 탕수육 중짜 한그릇 주세요.\",\n"," \"아이스아메리카노 톨사이즈 한잔 하고요. 딸기스무디 한잔 주세요. 또, 콜드브루라떼 하나요.\",\n"," \"참이슬 한병, 코카콜라 1.5리터 한병, 테슬라 한병이요.\",\n"," \"꼬막무침 1인분하고요, 닭도리탕 중자 주세요. 그리고 소주도 한병 주세요.\",\n"," \"김치찌개 3인분하고요, 계란말이 주세요.\",\n"," \"불고기버거세트 1개하고요 감자튀김 추가해주세요.\",\n"," \"불닭볶음면 1개랑 사리곰탕면 2개 주세요.\",\n"," \"카페라떼 아이스 샷추가 한잔하구요. 스콘 하나 주세요\",\n"," \"여기요 춘천닭갈비 4인분하고요. 라면사리 추가하겠습니다. 콜라 300ml 두캔주세요.\",\n"," \"있잖아요 조랭이떡국 3인분하고요. 
def wrapper_generate(model, input_prompt, max_new_tokens=256):
    """Generate a completion for ``input_prompt`` and return only the new text.

    The continuation is also streamed to stdout through a ``TextStreamer``
    while generation runs.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        input_prompt: Fully formatted prompt string.
        max_new_tokens: Upper bound on the number of generated tokens. This
            replaces the previous ``float('inf')``, which is not an integer and
            left generation unbounded whenever the model failed to emit EOS.

    Returns:
        The decoded continuation with the prompt and special tokens removed.

    Note:
        Relies on the notebook-level ``tokenizer`` created in an earlier cell.
    """
    data = tokenizer(input_prompt, return_tensors="pt")
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    input_ids = data.input_ids
    # Only strip a trailing EOS that the tokenizer may have appended. Dropping the
    # last token unconditionally removed real prompt text (the model then had to
    # re-emit the prompt's final ';' as its first generated token).
    if input_ids.shape[-1] > 0 and input_ids[0, -1].item() == tokenizer.eos_token_id:
        input_ids = input_ids[..., :-1]

    # Follow the model's device instead of assuming the default CUDA device.
    input_ids = input_ids.to(model.device)

    with torch.no_grad():
        pred = model.generate(
            input_ids=input_ids,
            streamer=streamer,
            use_cache=True,
            max_new_tokens=max_new_tokens,
            temperature=0.5,
        )

    # Decode only the tokens produced after the prompt; this does not depend on
    # decode(encode(prompt)) reproducing the prompt character-for-character.
    new_tokens = pred[0, input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# %% Run every evaluation query through the model, keyed by the query's index.
# The loop variable is used directly instead of re-indexing evaluation_queries.
eval_dic = {
    i: wrapper_generate(
        model=base_model,
        input_prompt=prompt_template.format(System=default_system_msg, User=query),
    )
    for i, query in enumerate(evaluation_queries)
}

# %% Show the returned (not streamed) text for the first query.
print(eval_dic[0])

# %% 4-bit NF4 quantization config with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
\n"," \n"," \n"," \n"," Step \n"," Training Loss \n"," \n"," \n"," 100 \n"," 0.268500 \n"," \n"," \n"," 200 \n"," 0.274100 \n"," \n"," \n"," 300 \n"," 0.278600 \n"," \n"," \n"," 400 \n"," 0.270100 \n"," \n"," \n"," 500 \n"," 0.271700 \n"," \n"," \n"," 600 \n"," 0.268900 \n"," \n"," \n"," 700 \n"," 0.263000 \n"," \n"," \n"," \n","800 \n"," 0.262800 \n","
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.