temporary0-0name committed on
Commit
17f3846
1 Parent(s): c8a3ed0

Create generation.py

Files changed (1)
  1. generation.py +51 -0
generation.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from torch.nn import functional as F
+ from gpt_class import GPTConfig, GPT
+ # Assuming tiktoken is correctly imported and functions as expected
+ import tiktoken
+
+ # Set up device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load model checkpoint (weights_only=False lets newer PyTorch unpickle the config object)
+ state_dict = torch.load('model_51999.pt', map_location=device, weights_only=False)
+ config = state_dict['config']
+ model = GPT(config)
+ model.load_state_dict(state_dict['model'])
+ model.to(device)
+ model.eval()
+
+ # Set seeds for reproducibility
+ torch.manual_seed(42)
+ torch.cuda.manual_seed_all(42)
+
+ # Get tokenizer
+ tokenizer = tiktoken.get_encoding("gpt2")
+
+ def Generate(model, tokenizer, example, num_return_sequences, max_length):
+     model.eval()
+     tokens = tokenizer.encode(example)
+     tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).repeat(num_return_sequences, 1)
+     tokens = tokens.to(device)
+     sample_rng = torch.Generator(device=device)
+     sample_rng.manual_seed(42)  # seed the sampling generator too, else multinomial is nondeterministic
+     xgen = tokens
+     while xgen.size(1) < max_length:
+         with torch.no_grad():
+             with torch.autocast(device_type=device):
+                 logits, _ = model(xgen)  # assumes model returns (logits, loss)
+             logits = logits[:, -1, :]  # logits for the last position only
+             probs = F.softmax(logits, dim=-1)
+             topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)  # keep the 50 most likely tokens
+             ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # sample within the top-k
+             xcol = torch.gather(topk_indices, -1, ix)  # map back to vocabulary ids
+             xgen = torch.cat((xgen, xcol), dim=1)  # append the sampled token
+
+     # Decode and print each generated sequence
+     for i in range(num_return_sequences):
+         tokens = xgen[i, :max_length].tolist()
+         decoded = tokenizer.decode(tokens)
+         print(f"Sample {i+1}: {decoded}")
+
+ # Generate text
+ Generate(model, tokenizer, example="It is raining outside and", num_return_sequences=4, max_length=64)
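
For reference, the sampling step inside the loop reduces to the following self-contained sketch (dummy logits over a toy vocabulary of 100 tokens, not the checkpointed model): softmax over the vocabulary, truncate to the 50 most probable tokens, sample one, and map the sampled position back to a vocabulary id.

import torch
from torch.nn import functional as F

logits = torch.randn(4, 100)                               # dummy (batch, vocab) logits
probs = F.softmax(logits, dim=-1)                          # full distribution
topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)   # truncate to the top-50
ix = torch.multinomial(topk_probs, 1)                      # sample a position within the top-k slice
next_token = torch.gather(topk_indices, -1, ix)            # recover vocabulary ids, shape (4, 1)
print(next_token.shape)                                    # torch.Size([4, 1])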