pt-sk committed
Commit 6e43012
1 Parent(s): 26208f1

Upload train_and_inference.ipynb

Files changed (1)
  code/train_and_inference.ipynb +107 -0
code/train_and_inference.ipynb ADDED
@@ -0,0 +1,107 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {
+     "id": "_nSGBgG98qRC"
+    },
+    "outputs": [],
+    "source": [
+     "# Trainer\n",
+     "import torch\n",
+     "from tqdm import tqdm\n",
+     "\n",
+     "iterator = tqdm(dataloader, desc=\"Training\", postfix={\"train_loss\": 0.0})\n",
+     "\n",
+     "for item in iterator:\n",
+     "    item = tokenizer.bos_token + \" \" + item[0] + \" \" + tokenizer.eos_token\n",
+     "    encoded_inp = tokenizer(item, return_tensors='pt').input_ids.to(\"cuda\")\n",
+     "    logits = mamba_model(encoded_inp)\n",
+     "\n",
+     "    labels = encoded_inp.to(logits.device)\n",
+     "    shift_logits = logits[:, :-1, :].contiguous()  # drop the last position\n",
+     "    labels = labels[:, 1:].contiguous()  # shift labels: position t predicts token t+1\n",
+     "    loss_fct = torch.nn.CrossEntropyLoss()\n",
+     "    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))\n",
+     "\n",
+     "    optimizer.zero_grad(set_to_none=True)\n",
+     "    loss.backward()\n",
+     "    optimizer.step()\n",
+     "\n",
+     "    # detach to CPU so the GPU tensors can be freed\n",
+     "    loss = loss.detach().cpu().numpy()\n",
+     "    logits = logits.detach().cpu().numpy()\n",
+     "    labels = labels.detach().cpu().numpy()\n",
+     "    encoded_inp = encoded_inp.detach().cpu().numpy()\n",
+     "    shift_logits = shift_logits.detach().cpu().numpy()\n",
+     "\n",
+     "    iterator.set_postfix({\"train_loss\": loss.item()}, refresh=False)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 14,
+    "metadata": {
+     "id": "feaR0XKtOGug"
+    },
+    "outputs": [],
+    "source": [
+     "# Inference\n",
+     "import torch\n",
+     "import torch.nn.functional as F\n",
+     "\n",
+     "\n",
+     "def generate(model,\n",
+     "             tokenizer,\n",
+     "             prompt: str,\n",
+     "             n_tokens_to_gen: int = 200,\n",
+     "             sample: bool = True,\n",
+     "             top_k: int = 40):\n",
+     "    model.eval()\n",
+     "\n",
+     "    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(\"cuda\")\n",
+     "\n",
+     "    for token_n in range(n_tokens_to_gen):\n",
+     "        with torch.no_grad():\n",
+     "            indices_to_input = input_ids\n",
+     "            next_token_logits = model(indices_to_input)[:, -1]\n",
+     "\n",
+     "        probs = F.softmax(next_token_logits, dim=-1)\n",
+     "        (batch, vocab_size) = probs.shape\n",
+     "\n",
+     "        if top_k is not None:\n",
+     "            (values, indices) = torch.topk(probs, k=top_k)\n",
+     "            probs[probs < values[:, -1, None]] = 0  # zero out everything below the k-th largest prob\n",
+     "            probs = probs / probs.sum(dim=1, keepdim=True)  # renormalize over the surviving tokens\n",
+     "\n",
+     "        if sample:\n",
+     "            next_indices = torch.multinomial(probs, num_samples=1)\n",
+     "        else:\n",
+     "            next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
+     "\n",
+     "        input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
+     "\n",
+     "    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n",
+     "\n",
+     "    return output_completions"
+    ]
+   }
+  ],
+  "metadata": {
+   "accelerator": "GPU",
+   "colab": {
+    "gpuType": "T4",
+    "provenance": []
+   },
+   "kernelspec": {
+    "display_name": "Python 3",
+    "name": "python3"
+   },
+   "language_info": {
+    "name": "python"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+ }
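
Neither cell defines `mamba_model`, `tokenizer`, `dataloader`, or `optimizer`; the notebook assumes they already exist in the session. Below is a minimal, hypothetical sketch of that setup. Every name in it (the `ToyLM` stand-in, the `EleutherAI/gpt-neox-20b` tokenizer, the learning rate, the placeholder texts) is an assumption for illustration and does not appear in this commit.

    # Hypothetical glue code for the cells above; nothing here is part of the commit.
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # Assumed tokenizer: any Hugging Face tokenizer that defines bos_token and eos_token works.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

    # The training cell expects a model that maps token ids (batch, seq_len)
    # directly to logits (batch, seq_len, vocab_size). A tiny stand-in for a
    # smoke test; a real run would substitute an actual Mamba language model:
    class ToyLM(torch.nn.Module):
        def __init__(self, vocab_size, d_model=64):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, d_model)
            self.head = torch.nn.Linear(d_model, vocab_size)

        def forward(self, input_ids):
            return self.head(self.embed(input_ids))

    mamba_model = ToyLM(len(tokenizer)).to("cuda")
    optimizer = torch.optim.AdamW(mamba_model.parameters(), lr=1e-4)

    # The loop reads item[0], i.e. batches of raw strings with batch_size=1,
    # so a DataLoader over a plain list of texts is enough:
    dataloader = DataLoader(["an example sentence", "another example"], batch_size=1)

With those names in place, the training cell runs as written, and the inference cell can be exercised with, e.g., `generate(mamba_model, tokenizer, "Mamba is", n_tokens_to_gen=20)`.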