File size: 966 Bytes
87a5635
3c6c416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87a5635
 
 
3c6c416
 
 
 
 
87a5635
 
 
 
3c6c416
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from typing import Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM

# Module-level cache for the lazily loaded tokenizer and model.
# Both start as None and are populated on the first call to
# get_model_and_tokenizer(); later calls reuse these instances.
tokenizer = None
model = None


def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Lazily load and cache the causal-LM model and its tokenizer.

    The first call downloads/loads both objects from the Hugging Face hub
    and moves the model to GPU when one is available; every subsequent call
    returns the same cached instances.

    Returns:
        tuple: A tuple containing the preloaded model and tokenizer.
    """
    global model, tokenizer

    # Fast path: both objects already cached — nothing to load.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    # Prefer the GPU when CUDA is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # First use: fetch tokenizer and model, then place the model on the
    # chosen device before caching it.
    tokenizer = AutoTokenizer.from_pretrained("juancopi81/lmd_8bars_tokenizer")
    model = AutoModelForCausalLM.from_pretrained(
        "juancopi81/lmd-8bars-2048-epochs20_v3"
    ).to(device)

    return model, tokenizer