import gc

import ffmpegio
import torch
from transformers import MobileViTForSemanticSegmentation, MobileViTImageProcessor

from config import BATCH_SIZE, FPS_DIV, MAX_LENGTH, MODEL_PATH


class PreprocessModel(torch.nn.Module):
    device = 'cpu'

    def __init__(self):
        super().__init__()
        self.feature_extractor = MobileViTImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        self.mobile_vit = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
        # Halve the spatial size of the segmentation logits: (21, 32, 32) -> (21, 16, 16)
        self.convs = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mobile_vit(x).logits  # per-frame semantic segmentation logits, 21 PASCAL VOC classes
        x = self.convs(x)
        return x

    def read_video(self, path: str) -> torch.Tensor:
        """
        Reads a video and returns a tensor of per-frame segmentation features.
        """
        _, video = ffmpegio.video.read(path, t=1.0)  # returns (rate, frames); t=1.0 caps the read duration
        video = video[::FPS_DIV][:MAX_LENGTH]  # subsample frames and cap the sequence length
        out_seg_video = []
        for i in range(0, video.shape[0], BATCH_SIZE):
            frames = [video[j] for j in range(i, min(i + BATCH_SIZE, video.shape[0]))]
            frames = self.feature_extractor(images=frames, return_tensors='pt')['pixel_values']
            out = self.forward(frames.to(self.device)).detach().to('cpu')
            out_seg_video.append(out)
            # Release the batch eagerly to keep peak memory low
            del frames, out
            gc.collect()
            if self.device == 'cuda':
                torch.cuda.empty_cache()
        return torch.cat(out_seg_video)
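
# A usage sketch for the preprocessing stage (illustrative, not part of the
# pipeline; 'data/sample.mp4' is a hypothetical placeholder path). read_video
# returns pooled MobileViT logits of shape (frames, 21, 16, 16), which is the
# input expected by VideoModel below.
#
# preprocess = PreprocessModel().eval()
# with torch.no_grad():
#     features = preprocess.read_video('data/sample.mp4')
# print(features.shape)  # torch.Size([frames, 21, 16, 16])
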
class VideoModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        p = 0.5
        # Encodes the preview frame into the initial LSTM hidden state: (21, 16, 16) -> 256
        self.pic_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(21, 128, (2, 2), stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(128, 256, (2, 2), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(256, 256, (4, 4), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.Flatten()
        )
        # Encodes each video frame into an LSTM input vector: (21, 16, 16) -> 2048
        self.vid_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(21, 128, (2, 2), stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.Tanh(),
            torch.nn.Conv2d(128, 256, (2, 2), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(256, 512, (2, 2), stride=2),
            torch.nn.BatchNorm2d(512),
            torch.nn.Dropout2d(p),
            torch.nn.Flatten()
        )
        self.lstm = torch.nn.LSTM(2048, 256, 1, batch_first=True, bidirectional=True)
        self.fc1 = torch.nn.Linear(256 * 2, 1024)
        self.fc_norm = torch.nn.BatchNorm1d(256 * 2)
        self.tanh = torch.nn.Tanh()
        self.fc2 = torch.nn.Linear(1024, 2)
        self.sigmoid = torch.nn.Sigmoid()
        self.dropout = torch.nn.Dropout(p)

        # Xavier init
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """
        Uses the preview as the initial hidden state of the LSTM and the video
        frames as the input sequence.

        video[0] is the preview, video[1:] are the video frames.

        :param video: torch.Tensor, shape = (frames + 1, 21, 16, 16)
        """
        # Split off the preview before padding, otherwise left-padding would shift it
        _preview = video[:1]
        _video = video[1:]
        # Left-pad the frame sequence to MAX_LENGTH so the real frames end at the
        # last timestep, whose LSTM output feeds the classifier head
        _video = torch.nn.functional.pad(_video, (0, 0, 0, 0, 0, 0, MAX_LENGTH - _video.shape[0], 0))
        _video = _video.unsqueeze(0)
        _batch_size = _video.shape[0]
        # Preview features initialise the forward direction; the pad appends a zero
        # initial state for the backward direction of the bidirectional LSTM
        h0 = self.pic_cnn(_preview).unsqueeze(0)
        h0 = torch.nn.functional.pad(h0, (0, 0, 0, 0, 0, 1))
        c0 = torch.zeros_like(h0)
        _video = self.vid_cnn(_video.reshape(-1, 21, 16, 16))
        _video = _video.reshape(_batch_size, MAX_LENGTH, -1)
        context, _ = self.lstm(_video, (h0, c0))
        out = self.fc_norm(context[:, -1])
        out = self.tanh(self.fc1(out))
        out = self.dropout(out)
        out = self.sigmoid(self.fc2(out))
        return out


# @st.cache_resource
class TikTokAnalytics(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.preprocessing_model = PreprocessModel()
        self.predict_model = torch.load(MODEL_PATH, map_location=self.preprocessing_model.device)
        self.preprocessing_model.eval()
        self.predict_model.eval()

    def forward(self, path: str) -> torch.Tensor:
        """
        Runs preprocessing, then prediction.

        :param path: path to the video file
        :return: model prediction
        """
        tensor = self.preprocessing_model.read_video(path)
        predict = self.predict_model(tensor)
        return predict


# if __name__ == '__main__':
#     model = TikTokAnalytics()
#     model = model(
#         '/Users/victorbarbarich/PycharmProjects/nueramic/vktrbr-video-tiktok/data/videos/video-6930454291186502917.mp4')
#     print(model)
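
# A shape sanity check for VideoModel (a sketch, not part of the original
# pipeline): feeds a random tensor shaped like PreprocessModel's output --
# here an arbitrary 30 frames plus a preview -- through a freshly initialised
# model. Assumes MAX_LENGTH >= 30; needs no trained weights or video file.
if __name__ == '__main__':
    model = VideoModel().eval()  # eval mode: BatchNorm/Dropout in inference behaviour
    dummy = torch.randn(30 + 1, 21, 16, 16)  # preview + 30 frames of segmentation features
    with torch.no_grad():
        print(model(dummy))  # shape (1, 2), values in (0, 1)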