EnariGmbH commited on
Commit
a700cdc
1 Parent(s): 018de9e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +20 -5
handler.py CHANGED
@@ -1,13 +1,28 @@
1
  from typing import Dict, List, Any
2
  import torch
3
- from transformers import LlavaNextVideoForConditionalGeneration, AutoProcessor, AutoConfig
 
4
 
5
  class EndpointHandler:
6
- def __init__(self, path="/app"):
7
- self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(path)
 
8
 
9
- # Load the processor from the configuration files
10
- self.processor = AutoProcessor.from_pretrained(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Ensure the model is in evaluation mode
13
  self.model.eval()
 
1
  from typing import Dict, List, Any
2
  import torch
3
+ from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
4
+ from peft import PeftModel
5
 
6
  class EndpointHandler:
7
+ def __init__(self):
8
+ self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
9
+ self.adapter_model_name = "EnariGmbH/surftown-1.0"
10
 
11
+ # Load the base model
12
+ self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
13
+ self.base_model_name,
14
+ torch_dtype=torch.float16,
15
+ device_map="auto"
16
+ )
17
+
18
+ # Load the fine-tuned adapter model into the base model
19
+ self.model = PeftModel.from_pretrained(self.model, self.adapter_model_name)
20
+
21
+ # Merge the adapter weights into the base model and unload the adapter
22
+ self.model = self.model.merge_and_unload()
23
+
24
+ # # Optionally, load and save the processor (if needed)
25
+ self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
26
 
27
  # Ensure the model is in evaluation mode
28
  self.model.eval()