import os
from cog import BasePredictor, Input, Path
import torch
import json
import sys
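# Put the directory above this script's directory on sys.path so the `src`
# package resolves regardless of the working directory Cog launches from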
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.models.model import load_model
from src.dataset.video_utils import create_transform, extract_frames

CHECKPOINT_DIR = "checkpoints/"

class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # Load configuration from JSON
        with open(os.path.join(CHECKPOINT_DIR, "config.json"), "r") as f:
            self.config = json.load(f)
        
        # Create transform
        self.transform = create_transform(self.config, training=False)
        
        # Load model
        self.model = load_model(
            self.config['num_classes'],
            os.path.join(CHECKPOINT_DIR, "weights.ckpt"),
            self.device,
            self.config['clip_model']
        )
        self.model.eval()

    def predict(self, video: Path = Input(description="Input video file")) -> dict:
        """Run a single prediction on the model"""
        try:
            # Extract frames using shared function with config
            frames, success = extract_frames(str(video), self.config, self.transform)
            
            if not success or frames is None:
                raise ValueError(f"Failed to process video: {video}")
            
            # extract_frames returns a single tensor; add a batch dimension
            frames = frames.unsqueeze(0).to(self.device)
            
            # Get prediction
            with torch.no_grad():
                output = self.model(frames)
                probabilities = torch.softmax(output, dim=1)
                predicted_class = torch.argmax(probabilities, dim=1).item()
                confidence = probabilities[0][predicted_class].item()
                
                # Get all class confidences
                all_confidences = {
                    label: probabilities[0][i].item()
                    for i, label in enumerate(self.config['class_labels'])
                }
            
            return {
                "class": self.config['class_labels'][predicted_class],
                "confidence": confidence,
                "all_confidences": all_confidences
            }
            
        except Exception as e:
            # Chain with `from e` so the original traceback is preserved
            raise ValueError(f"Error processing video: {e}") from e