{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0-7S1J6Jq7nc" }, "source": [ "# Fine-Tuning BERT as a `RewardModel`\n", "\n", "1. First, install `transformers`, `trl`, and `codecarbon`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Fx7pg9eT62-d", "outputId": "94385892-b91f-4ffc-d290-9ace4a031edd" }, "outputs": [], "source": [ "!pip install transformers trl codecarbon -q" ] }, { "cell_type": "markdown", "metadata": { "id": "Y6xzGtxPrMaF" }, "source": [ "2. Download the `reward-aira-dataset-comparisons` dataset from the Hub."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 866 }, "id": "DtCgCgEr62C9", "outputId": "438cc273-60b6-4c4b-a320-336803bf0397" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset loaded.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " | instruction | \n", "chosen_response | \n", "rejected_response | \n", "
---|---|---|---|
0 | \n", "Estarei em Roma no próximo mês e gostaria de s... | \n", "Que bom saber que você visitará Roma! Roma est... | \n", "Roma é uma cidade repleta de atrações e pontos... | \n", "
1 | \n", "Como determino o valor de um item antigo. | \n", "Determinar o valor de um item antigo pode ser ... | \n", "Para determinar o valor de um item antigo, pes... | \n", "
2 | \n", "A engenharia de recursos sempre melhora o dese... | \n", "A engenharia de recursos é uma etapa crítica n... | \n", "Para otimizar o desempenho do aprendizado de m... | \n", "
3 | \n", "Qual é a diferença entre um oval e um círculo. | \n", "Um oval é uma forma esticada que é mais longa ... | \n", "A principal diferença entre um oval e um círcu... | \n", "
4 | \n", "Qual é a história do Canal de Suez. | \n", "O Canal de Suez é uma via navegável artificial... | \n", "O Canal de Suez é uma via navegável artificial... | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
32670 | \n", "É verdade que o núcleo da Terra é feito de ferro? | \n", "Sim, é geralmente aceito que o núcleo da Terra... | \n", "Sim, é verdade que o núcleo da Terra é feito d... | \n", "
32671 | \n", "Estou fazendo um experimento científico, mas n... | \n", "Claro! Ficarei feliz em ajudá-lo a solucionar ... | \n", "Sim, certamente posso ajudá-lo a solucionar o ... | \n", "
32672 | \n", "Quanto exercício preciso fazer por dia. | \n", "A quantidade de exercício que você precisa faz... | \n", "Depende do seu estilo de vida e objetivos de c... | \n", "
32673 | \n", "Quais são os erros mais comuns no planejamento... | \n", "Existem vários erros comuns que as pessoas pod... | \n", "Não fazer orçamento para despesas inesperadas.... | \n", "
32674 | \n", "Estou planejando um acampamento para este fim ... | \n", "Claro! Ficarei feliz em ajudá-lo a encontrar a... | \n", "Sugiro pesquisar acampamentos próximos à sua l... | \n", "
32675 rows × 3 columns
\n", "Step | \n", "Training Loss | \n", "
---|---|
200 | \n", "0.039600 | \n", "
400 | \n", "0.008300 | \n", "
600 | \n", "0.007400 | \n", "
800 | \n", "0.006200 | \n", "
1000 | \n", "0.002000 | \n", "
1200 | \n", "0.001300 | \n", "
"
],
"text/plain": [
"