File size: 563 Bytes
bdf9962
 
 
fc91aa0
bdf9962
fc91aa0
bdf9962
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
import spaces

@spaces.GPU
def load_model_and_processor(model_path):
    """Load the multimodal causal LM and its chat processor from *model_path*.

    The processor is created with ``VLChatProcessor.from_pretrained`` and the
    model with ``AutoModelForCausalLM.from_pretrained``. The model is cast to
    ``torch.bfloat16``, moved to the CUDA device and switched to eval mode
    before being returned.

    Args:
        model_path: Identifier handed unchanged to both ``from_pretrained``
            calls (presumably a Hub repo id or a local directory -- confirm
            against the caller).

    Returns:
        A ``(vl_gpt, vl_chat_processor)`` tuple: the prepared model followed
        by the processor.
    """
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)

    # NOTE(review): trust_remote_code=True lets transformers run Python code
    # shipped with the checkpoint; only point this at trusted model sources.
    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    # The explicit cast is kept from the original even though the weights are
    # already requested in bfloat16 at load time.
    vl_gpt = vl_gpt.to(torch.bfloat16)
    vl_gpt = vl_gpt.cuda()
    vl_gpt = vl_gpt.eval()

    return vl_gpt, vl_chat_processor