import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import HfApi, login

class MultimodalAI:
    def __init__(self):
        # Read the Hugging Face token from the environment (it can be loaded from a .env file)
        self.HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")

        # os.environ.get returns None instead of raising, so fail with a clear error here
        if self.HUGGINGFACE_TOKEN is None:
            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set.")

        # Authenticate with Hugging Face
        self.api = HfApi()
        login(token=self.HUGGINGFACE_TOKEN)

        # Model selection
        self.model_name = "Lin-Chen/sharegpt4video-8b"

        # Check if a CUDA-enabled GPU is available.
        # If available, move the model to the GPU (cuda:0) for faster computation.
        # Otherwise, move the model to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the model and tokenizer
        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        # Load the causal LM and its tokenizer from the Hugging Face Hub
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, 
                                                          token=self.HUGGINGFACE_TOKEN).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, 
                                                       token=self.HUGGINGFACE_TOKEN)

    def generate_response(self, text_input, max_new_tokens=50):
        # Tokenize input text
        inputs = self.tokenizer(text_input, return_tensors="pt").to(self.device)

        # Generate a response; LLaMA-style tokenizers often define no pad token,
        # so fall back to the EOS token id when pad_token_id is None
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=pad_id)

        # Decode and return the response
        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response_text
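
# Minimal usage sketch: build the wrapper and generate a short completion.
# Assumes HUGGINGFACE_TOKEN is set in the environment and that the account
# has access to the model weights; the prompt string is an arbitrary example.
if __name__ == "__main__":
    ai = MultimodalAI()
    print(ai.generate_response("Describe a sunset over the ocean.", max_new_tokens=64))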