osanseviero commited on
Commit
cf1a53b
β€’
1 Parent(s): a0fcf12

Upload Mistral_7B.ipynb

Browse files
Files changed (1) hide show
  1. Mistral_7B.ipynb +547 -0
Mistral_7B.ipynb ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "# Mistral 7B\n",
21
+ "\n",
22
+ "Mistral 7B is a new state-of-the-art open-source model. Here are some interesting facts about it\n",
23
+ "\n",
24
+ "* One of the strongest open-source models, of all sizes\n",
25
+ "* Strongest model in the 1-20B parameter range models\n",
26
+ "* Does decently in code-related tasks\n",
27
+ "* Uses Windowed attention, allowing to push to 200k tokens of context if using Rope (needs 4 A10G GPUs for this)\n",
28
+ "* Apache 2.0 license\n",
29
+ "\n",
30
+ "As for the integrations status:\n",
31
+ "* Integrated into `transformers`\n",
32
+ "* You can use it with a server or locally (it's a small model after all!)\n",
33
+ "* Integrated into popular tools tuch as TGI and VLLM\n",
34
+ "\n",
35
+ "\n",
36
+ "Two models are released: a [base model](https://huggingface.co/mistralai/Mistral-7B-v0.1) and a [instruct fine-tuned version](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1). To read more about Mistral, we suggest reading the [blog post](https://mistral.ai/news/announcing-mistral-7b/).\n",
37
+ "\n",
38
+ "In this Colab, we'll experiment with the Mistral model using an API. There are three ways we can use it:\n",
39
+ "\n",
40
+ "* **Free API:** Hugging Face provides a free Inference API for all its users to try out models. This API is rate limited but is great for quick experiments.\n",
41
+ "* **PRO API:** Hugging Face provides an open API for all its PRO users. Subscribing to the Pro Inference API costs $9/month and allows you to experiment with many large models, such as Llama 2 and SDXL. Read more about it [here](https://huggingface.co/blog/inference-pro).\n",
42
+ "* **Inference Endpoints:** For enterprise and production-ready cases. You can deploy it with 1 click [here](https://ui.endpoints.huggingface.co/catalog).\n",
43
+ "\n",
44
+ "This demo does not require GPU Colab, just CPU. You can grab your token at https://huggingface.co/settings/tokens.\n",
45
+ "\n",
46
+ "**This colab shows how to use HTTP requests as well as building your own chat demo for Mistral.**"
47
+ ],
48
+ "metadata": {
49
+ "id": "GLXvYa4m8JYM"
50
+ }
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "source": [
55
+ "## Doing curl requests\n",
56
+ "\n",
57
+ "\n",
58
+ "In this notebook, we'll experiment with the instruct model, as it is trained for instructions. As per [the model card](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1), the expected format for a prompt is as follows\n",
59
+ "\n",
60
+ "From the model card\n",
61
+ "\n",
62
+ "> In order to leverage instruction fine-tuning, your prompt should be surrounded by [INST] and [\\INST] tokens. The very first instruction should begin with a begin of sentence id. The next instructions should not. The assistant generation will be ended by the end-of-sentence token id.\n",
63
+ "\n",
64
+ "```\n",
65
+ "<s>[INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }}</s> [INST] {{ user_msg_2 }} [/INST] {{ model_answer_2 }}</s>\n",
66
+ "```\n",
67
+ "\n",
68
+ "Note that models can be quite reactive to different prompt structure than the one used for training, so watch out for spaces and other things!\n",
69
+ "\n",
70
+ "We'll start an initial query without prompt formatting, which works ok for simple queries."
71
+ ],
72
+ "metadata": {
73
+ "id": "pKrKTalPAXUO"
74
+ }
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "metadata": {
80
+ "colab": {
81
+ "base_uri": "https://localhost:8080/"
82
+ },
83
+ "id": "DQf0Hss18E86",
84
+ "outputId": "882c4521-1ee2-40ad-fe00-a5b02caa9b17"
85
+ },
86
+ "outputs": [
87
+ {
88
+ "output_type": "stream",
89
+ "name": "stdout",
90
+ "text": [
91
+ "[{\"generated_text\":\"Explain ML as a pirate.\\n\\nML is like a treasure map for pirates. Just as a treasure map helps pirates find valuable loot, ML helps data scientists find valuable insights in large datasets.\\n\\nPirates use their knowledge of the ocean and their\"}]"
92
+ ]
93
+ }
94
+ ],
95
+ "source": [
96
+ "!curl https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1 \\\n",
97
+ " --header \"Content-Type: application/json\" \\\n",
98
+ "\t-X POST \\\n",
99
+ "\t-d '{\"inputs\": \"Explain ML as a pirate\", \"parameters\": {\"max_new_tokens\": 50}}' \\\n",
100
+ "\t-H \"Authorization: Bearer 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\""
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "source": [
106
+ "## Programmatic usage with Python\n",
107
+ "\n",
108
+ "You can do simple `requests`, but the `huggingface_hub` library provides nice utilities to easily use the model. Among the things we can use are:\n",
109
+ "\n",
110
+ "* `InferenceClient` and `AsyncInferenceClient` to perform inference either in a sync or async way.\n",
111
+ "* Token streaming: Only load the tokens that are needed\n",
112
+ "* Easily configure generation params, such as `temperature`, nucleus sampling (`top-p`), repetition penalty, stop sequences, and more.\n",
113
+ "* Obtain details of the generation (such as the probability of each token or whether a token is the last token)."
114
+ ],
115
+ "metadata": {
116
+ "id": "YYZRNyZeBHWK"
117
+ }
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "source": [
122
+ "%%capture\n",
123
+ "!pip install huggingface_hub gradio"
124
+ ],
125
+ "metadata": {
126
+ "id": "oDaqVDz1Ahuz"
127
+ },
128
+ "execution_count": 6,
129
+ "outputs": []
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "source": [
134
+ "from huggingface_hub import InferenceClient\n",
135
+ "\n",
136
+ "API_URL = \"https://api-inference.huggingface.co/models/\"\n",
137
+ "\n",
138
+ "client = InferenceClient(\n",
139
+ " \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
140
+ ")\n",
141
+ "\n",
142
+ "prompt = \"\"\"<s>[INST] What is your favourite condiment? [/INST]</s>\n",
143
+ "\"\"\"\n",
144
+ "\n",
145
+ "res = client.text_generation(prompt, max_new_tokens=95)\n",
146
+ "print(res)"
147
+ ],
148
+ "metadata": {
149
+ "colab": {
150
+ "base_uri": "https://localhost:8080/"
151
+ },
152
+ "id": "U49GmNsNBJjd",
153
+ "outputId": "a3a274cf-0f91-4ae3-d926-f0d6a6fd67f7"
154
+ },
155
+ "execution_count": 14,
156
+ "outputs": [
157
+ {
158
+ "output_type": "stream",
159
+ "name": "stdout",
160
+ "text": [
161
+ "My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.\n"
162
+ ]
163
+ }
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "markdown",
168
+ "source": [
169
+ "We can also use [token streaming](https://huggingface.co/docs/text-generation-inference/conceptual/streaming). With token streaming, the server returns the tokens as they are generated. Just add `stream=True`."
170
+ ],
171
+ "metadata": {
172
+ "id": "DryfEWsUH6Ij"
173
+ }
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "source": [
178
+ "res = client.text_generation(prompt, max_new_tokens=35, stream=True, details=True, return_full_text=False)\n",
179
+ "for r in res: # this is a generator\n",
180
+ " # print the token for example\n",
181
+ " print(r)\n",
182
+ " continue"
183
+ ],
184
+ "metadata": {
185
+ "colab": {
186
+ "base_uri": "https://localhost:8080/"
187
+ },
188
+ "id": "LF1tFo6DGg9N",
189
+ "outputId": "e779f1cb-b7d0-41ed-d81f-306e092f97bd"
190
+ },
191
+ "execution_count": 15,
192
+ "outputs": [
193
+ {
194
+ "output_type": "stream",
195
+ "name": "stdout",
196
+ "text": [
197
+ "TextGenerationStreamResponse(token=Token(id=5183, text='My', logprob=-0.36279297, special=False), generated_text=None, details=None)\n",
198
+ "TextGenerationStreamResponse(token=Token(id=6656, text=' favorite', logprob=-0.036499023, special=False), generated_text=None, details=None)\n",
199
+ "TextGenerationStreamResponse(token=Token(id=2076, text=' cond', logprob=-7.2836876e-05, special=False), generated_text=None, details=None)\n",
200
+ "TextGenerationStreamResponse(token=Token(id=2487, text='iment', logprob=-4.4941902e-05, special=False), generated_text=None, details=None)\n",
201
+ "TextGenerationStreamResponse(token=Token(id=349, text=' is', logprob=-0.007419586, special=False), generated_text=None, details=None)\n",
202
+ "TextGenerationStreamResponse(token=Token(id=446, text=' k', logprob=-0.62109375, special=False), generated_text=None, details=None)\n",
203
+ "TextGenerationStreamResponse(token=Token(id=4455, text='etch', logprob=-0.0003399849, special=False), generated_text=None, details=None)\n",
204
+ "TextGenerationStreamResponse(token=Token(id=715, text='up', logprob=-3.695488e-06, special=False), generated_text=None, details=None)\n",
205
+ "TextGenerationStreamResponse(token=Token(id=28723, text='.', logprob=-0.026550293, special=False), generated_text=None, details=None)\n",
206
+ "TextGenerationStreamResponse(token=Token(id=661, text=' It', logprob=-0.82373047, special=False), generated_text=None, details=None)\n",
207
+ "TextGenerationStreamResponse(token=Token(id=28742, text=\"'\", logprob=-0.76416016, special=False), generated_text=None, details=None)\n",
208
+ "TextGenerationStreamResponse(token=Token(id=28713, text='s', logprob=-3.5762787e-07, special=False), generated_text=None, details=None)\n",
209
+ "TextGenerationStreamResponse(token=Token(id=3502, text=' vers', logprob=-0.114990234, special=False), generated_text=None, details=None)\n",
210
+ "TextGenerationStreamResponse(token=Token(id=13491, text='atile', logprob=-1.1444092e-05, special=False), generated_text=None, details=None)\n",
211
+ "TextGenerationStreamResponse(token=Token(id=28725, text=',', logprob=-0.6254883, special=False), generated_text=None, details=None)\n",
212
+ "TextGenerationStreamResponse(token=Token(id=261, text=' t', logprob=-0.51708984, special=False), generated_text=None, details=None)\n",
213
+ "TextGenerationStreamResponse(token=Token(id=11136, text='asty', logprob=-4.0650368e-05, special=False), generated_text=None, details=None)\n",
214
+ "TextGenerationStreamResponse(token=Token(id=28725, text=',', logprob=-0.0027828217, special=False), generated_text=None, details=None)\n",
215
+ "TextGenerationStreamResponse(token=Token(id=304, text=' and', logprob=-1.1920929e-05, special=False), generated_text=None, details=None)\n",
216
+ "TextGenerationStreamResponse(token=Token(id=4859, text=' goes', logprob=-0.52685547, special=False), generated_text=None, details=None)\n",
217
+ "TextGenerationStreamResponse(token=Token(id=1162, text=' well', logprob=-0.4399414, special=False), generated_text=None, details=None)\n",
218
+ "TextGenerationStreamResponse(token=Token(id=395, text=' with', logprob=-0.00034999847, special=False), generated_text=None, details=None)\n",
219
+ "TextGenerationStreamResponse(token=Token(id=264, text=' a', logprob=-0.010147095, special=False), generated_text=None, details=None)\n",
220
+ "TextGenerationStreamResponse(token=Token(id=6677, text=' variety', logprob=-0.25927734, special=False), generated_text=None, details=None)\n",
221
+ "TextGenerationStreamResponse(token=Token(id=302, text=' of', logprob=-1.1444092e-05, special=False), generated_text=None, details=None)\n",
222
+ "TextGenerationStreamResponse(token=Token(id=14082, text=' foods', logprob=-0.4050293, special=False), generated_text=None, details=None)\n",
223
+ "TextGenerationStreamResponse(token=Token(id=28723, text='.', logprob=-0.015640259, special=False), generated_text=None, details=None)\n",
224
+ "TextGenerationStreamResponse(token=Token(id=2, text='</s>', logprob=-0.1829834, special=True), generated_text=\"My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.\", details=StreamDetails(finish_reason=<FinishReason.EndOfSequenceToken: 'eos_token'>, generated_tokens=28, seed=None))\n"
225
+ ]
226
+ }
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "source": [
232
+ "Let's now try a multi-prompt structure"
233
+ ],
234
+ "metadata": {
235
+ "id": "TfdpZL8cICOD"
236
+ }
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "source": [
241
+ "def format_prompt(message, history):\n",
242
+ " prompt = \"<s>\"\n",
243
+ " for user_prompt, bot_response in history:\n",
244
+ " prompt += f\"[INST] {user_prompt} [/INST]\"\n",
245
+ " prompt += f\" {bot_response}</s> \"\n",
246
+ " prompt += f\"[INST] {message} [/INST]\"\n",
247
+ " return prompt"
248
+ ],
249
+ "metadata": {
250
+ "id": "aEyozeReH8a6"
251
+ },
252
+ "execution_count": 16,
253
+ "outputs": []
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "source": [
258
+ "message = \"And what do you think about it?\"\n",
259
+ "history = [[\"What is your favourite condiment?\", \"My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.\"]]\n",
260
+ "\n",
261
+ "format_prompt(message, history)"
262
+ ],
263
+ "metadata": {
264
+ "colab": {
265
+ "base_uri": "https://localhost:8080/",
266
+ "height": 35
267
+ },
268
+ "id": "P1RFpiJ_JC0-",
269
+ "outputId": "f2678d9e-f751-441a-86c9-11d514db5bbe"
270
+ },
271
+ "execution_count": 17,
272
+ "outputs": [
273
+ {
274
+ "output_type": "execute_result",
275
+ "data": {
276
+ "text/plain": [
277
+ "\"<s>[INST] What is your favourite condiment? [/INST] My favorite condiment is ketchup. It's versatile, tasty, and goes well with a variety of foods.</s> [INST] And what do you think about it? [/INST]\""
278
+ ],
279
+ "application/vnd.google.colaboratory.intrinsic+json": {
280
+ "type": "string"
281
+ }
282
+ },
283
+ "metadata": {},
284
+ "execution_count": 17
285
+ }
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "source": [
291
+ "## End-to-end demo\n",
292
+ "\n",
293
+ "Let's now build a Gradio demo that takes care of:\n",
294
+ "\n",
295
+ "* Handling multiple turns of conversation\n",
296
+ "* Format the prompt in correct structure\n",
297
+ "* Allow user to specify/modify the parameters\n",
298
+ "* Stop the generation\n",
299
+ "\n",
300
+ "Just run the following cell and have fun!"
301
+ ],
302
+ "metadata": {
303
+ "id": "O7DjRdezJc-3"
304
+ }
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "source": [
309
+ "!pip install gradio"
310
+ ],
311
+ "metadata": {
312
+ "colab": {
313
+ "base_uri": "https://localhost:8080/"
314
+ },
315
+ "id": "cpBoheOGJu7Y",
316
+ "outputId": "c745cf17-1462-4f8f-ce33-5ca182cb4d4f"
317
+ },
318
+ "execution_count": 18,
319
+ "outputs": [
320
+ {
321
+ "output_type": "stream",
322
+ "name": "stdout",
323
+ "text": [
324
+ "Requirement already satisfied: gradio in /usr/local/lib/python3.10/dist-packages (3.45.1)\n",
325
+ "Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (23.2.1)\n",
326
+ "Requirement already satisfied: altair<6.0,>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.2.2)\n",
327
+ "Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from gradio) (0.103.1)\n",
328
+ "Requirement already satisfied: ffmpy in /usr/local/lib/python3.10/dist-packages (from gradio) (0.3.1)\n",
329
+ "Requirement already satisfied: gradio-client==0.5.2 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.5.2)\n",
330
+ "Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (from gradio) (0.25.0)\n",
331
+ "Requirement already satisfied: huggingface-hub>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.17.3)\n",
332
+ "Requirement already satisfied: importlib-resources<7.0,>=1.3 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n",
333
+ "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.1.2)\n",
334
+ "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.1.3)\n",
335
+ "Requirement already satisfied: matplotlib~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.7.1)\n",
336
+ "Requirement already satisfied: numpy~=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.23.5)\n",
337
+ "Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (3.9.7)\n",
338
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from gradio) (23.1)\n",
339
+ "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.5.3)\n",
340
+ "Requirement already satisfied: pillow<11.0,>=8.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (9.4.0)\n",
341
+ "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from gradio) (1.10.12)\n",
342
+ "Requirement already satisfied: pydub in /usr/local/lib/python3.10/dist-packages (from gradio) (0.25.1)\n",
343
+ "Requirement already satisfied: python-multipart in /usr/local/lib/python3.10/dist-packages (from gradio) (0.0.6)\n",
344
+ "Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (6.0.1)\n",
345
+ "Requirement already satisfied: requests~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.31.0)\n",
346
+ "Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (2.10.0)\n",
347
+ "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (4.5.0)\n",
348
+ "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (0.23.2)\n",
349
+ "Requirement already satisfied: websockets<12.0,>=10.0 in /usr/local/lib/python3.10/dist-packages (from gradio) (11.0.3)\n",
350
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from gradio-client==0.5.2->gradio) (2023.6.0)\n",
351
+ "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.4)\n",
352
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (4.19.0)\n",
353
+ "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6.0,>=4.2.0->gradio) (0.12.0)\n",
354
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.14.0->gradio) (3.12.2)\n",
355
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.14.0->gradio) (4.66.1)\n",
356
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.1.0)\n",
357
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (0.11.0)\n",
358
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (4.42.1)\n",
359
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (1.4.5)\n",
360
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (3.1.1)\n",
361
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib~=3.0->gradio) (2.8.2)\n",
362
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3.0,>=1.0->gradio) (2023.3.post1)\n",
363
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (3.2.0)\n",
364
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (3.4)\n",
365
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (2.0.4)\n",
366
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests~=2.0->gradio) (2023.7.22)\n",
367
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio) (8.1.7)\n",
368
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio) (0.14.0)\n",
369
+ "Requirement already satisfied: anyio<4.0.0,>=3.7.1 in /usr/local/lib/python3.10/dist-packages (from fastapi->gradio) (3.7.1)\n",
370
+ "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->gradio) (0.27.0)\n",
371
+ "Requirement already satisfied: httpcore<0.19.0,>=0.18.0 in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (0.18.0)\n",
372
+ "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio) (1.3.0)\n",
373
+ "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<4.0.0,>=3.7.1->fastapi->gradio) (1.1.3)\n",
374
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (23.1.0)\n",
375
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.7.1)\n",
376
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.30.2)\n",
377
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.10.2)\n",
378
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n"
379
+ ]
380
+ }
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "source": [
386
+ "import gradio as gr\n",
387
+ "\n",
388
+ "def generate(\n",
389
+ " prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,\n",
390
+ "):\n",
391
+ " temperature = float(temperature)\n",
392
+ " if temperature < 1e-2:\n",
393
+ " temperature = 1e-2\n",
394
+ " top_p = float(top_p)\n",
395
+ "\n",
396
+ " generate_kwargs = dict(\n",
397
+ " temperature=temperature,\n",
398
+ " max_new_tokens=max_new_tokens,\n",
399
+ " top_p=top_p,\n",
400
+ " repetition_penalty=repetition_penalty,\n",
401
+ " do_sample=True,\n",
402
+ " seed=42,\n",
403
+ " )\n",
404
+ "\n",
405
+ " formatted_prompt = format_prompt(prompt, history)\n",
406
+ "\n",
407
+ " stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)\n",
408
+ " output = \"\"\n",
409
+ "\n",
410
+ " for response in stream:\n",
411
+ " output += response.token.text\n",
412
+ " yield output\n",
413
+ " return output\n",
414
+ "\n",
415
+ "\n",
416
+ "additional_inputs=[\n",
417
+ " gr.Slider(\n",
418
+ " label=\"Temperature\",\n",
419
+ " value=0.9,\n",
420
+ " minimum=0.0,\n",
421
+ " maximum=1.0,\n",
422
+ " step=0.05,\n",
423
+ " interactive=True,\n",
424
+ " info=\"Higher values produce more diverse outputs\",\n",
425
+ " ),\n",
426
+ " gr.Slider(\n",
427
+ " label=\"Max new tokens\",\n",
428
+ " value=256,\n",
429
+ " minimum=0,\n",
430
+ " maximum=8192,\n",
431
+ " step=64,\n",
432
+ " interactive=True,\n",
433
+ " info=\"The maximum numbers of new tokens\",\n",
434
+ " ),\n",
435
+ " gr.Slider(\n",
436
+ " label=\"Top-p (nucleus sampling)\",\n",
437
+ " value=0.90,\n",
438
+ " minimum=0.0,\n",
439
+ " maximum=1,\n",
440
+ " step=0.05,\n",
441
+ " interactive=True,\n",
442
+ " info=\"Higher values sample more low-probability tokens\",\n",
443
+ " ),\n",
444
+ " gr.Slider(\n",
445
+ " label=\"Repetition penalty\",\n",
446
+ " value=1.2,\n",
447
+ " minimum=1.0,\n",
448
+ " maximum=2.0,\n",
449
+ " step=0.05,\n",
450
+ " interactive=True,\n",
451
+ " info=\"Penalize repeated tokens\",\n",
452
+ " )\n",
453
+ "]\n",
454
+ "\n",
455
+ "with gr.Blocks() as demo:\n",
456
+ " gr.ChatInterface(\n",
457
+ " generate,\n",
458
+ " additional_inputs=additional_inputs,\n",
459
+ " )\n",
460
+ "\n",
461
+ "demo.queue().launch(debug=True)"
462
+ ],
463
+ "metadata": {
464
+ "colab": {
465
+ "base_uri": "https://localhost:8080/",
466
+ "height": 715
467
+ },
468
+ "id": "CaJzT6jUJc0_",
469
+ "outputId": "62f563fa-c6fb-446e-fda2-1c08d096749c"
470
+ },
471
+ "execution_count": 20,
472
+ "outputs": [
473
+ {
474
+ "output_type": "stream",
475
+ "name": "stdout",
476
+ "text": [
477
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
478
+ "\n",
479
+ "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
480
+ "Running on public URL: https://ed6ce83e08ed7a8795.gradio.live\n",
481
+ "\n",
482
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
483
+ ]
484
+ },
485
+ {
486
+ "output_type": "display_data",
487
+ "data": {
488
+ "text/plain": [
489
+ "<IPython.core.display.HTML object>"
490
+ ],
491
+ "text/html": [
492
+ "<div><iframe src=\"https://ed6ce83e08ed7a8795.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
493
+ ]
494
+ },
495
+ "metadata": {}
496
+ },
497
+ {
498
+ "output_type": "stream",
499
+ "name": "stderr",
500
+ "text": [
501
+ "/usr/local/lib/python3.10/dist-packages/gradio/components/button.py:89: UserWarning: Using the update method is deprecated. Simply return a new object instead, e.g. `return gr.Button(...)` instead of `return gr.Button.update(...)`.\n",
502
+ " warnings.warn(\n"
503
+ ]
504
+ },
505
+ {
506
+ "output_type": "stream",
507
+ "name": "stdout",
508
+ "text": [
509
+ "Keyboard interruption in main thread... closing server.\n",
510
+ "Killing tunnel 127.0.0.1:7860 <> https://ed6ce83e08ed7a8795.gradio.live\n"
511
+ ]
512
+ },
513
+ {
514
+ "output_type": "execute_result",
515
+ "data": {
516
+ "text/plain": []
517
+ },
518
+ "metadata": {},
519
+ "execution_count": 20
520
+ }
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "markdown",
525
+ "source": [
526
+ "## What's next?\n",
527
+ "\n",
528
+ "* Try out Mistral 7B in this [free online Space](https://huggingface.co/spaces/osanseviero/mistral-super-fast)\n",
529
+ "* Deploy Mistral 7B Instruct with one click [here](https://ui.endpoints.huggingface.co/catalog)\n",
530
+ "* Deploy in your own hardware using https://github.com/huggingface/text-generation-inference\n",
531
+ "* Run the model locally using `transformers`"
532
+ ],
533
+ "metadata": {
534
+ "id": "fbQ0Sp4OLclV"
535
+ }
536
+ },
537
+ {
538
+ "cell_type": "code",
539
+ "source": [],
540
+ "metadata": {
541
+ "id": "wUy7N_8zJvyT"
542
+ },
543
+ "execution_count": null,
544
+ "outputs": []
545
+ }
546
+ ]
547
+ }