""" File: load_models.py Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov Description: Load pretrained models. License: MIT License """ import math import numpy as np import cv2 import torch.nn.functional as F import torch.nn as nn import torch from typing import Optional from PIL import Image from ultralytics import YOLO from transformers.models.wav2vec2.modeling_wav2vec2 import ( Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) from transformers import AutoConfig, Wav2Vec2Processor, AutoTokenizer, AutoModel from app.utils import pth_processing, get_idx_frames_in_windows # Importing necessary components for the Gradio app from app.utils import load_model class ScaledDotProductAttention_MultiHead(nn.Module): def __init__(self): super(ScaledDotProductAttention_MultiHead, self).__init__() self.softmax = nn.Softmax(dim=-1) def forward(self, query, key, value, mask=None): if mask is not None: raise ValueError("Mask is not supported yet") # key, query, value shapes: [batch_size, num_heads, seq_len, dim] emb_dim = key.shape[-1] # Calculate attention weights attention_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( emb_dim ) # masking if mask is not None: raise ValueError("Mask is not supported yet") # Softmax attention_weights = self.softmax(attention_weights) # modify value value = torch.matmul(attention_weights, value) return value, attention_weights class PositionWiseFeedForward(nn.Module): def __init__(self, input_dim, hidden_dim, dropout: float = 0.1): super().__init__() self.layer_1 = nn.Linear(input_dim, hidden_dim) self.layer_2 = nn.Linear(hidden_dim, input_dim) self.layer_norm = nn.LayerNorm(input_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): # feed-forward network x = self.layer_1(x) x = self.dropout(x) x = F.relu(x) x = self.layer_2(x) return x class Add_and_Norm(nn.Module): def __init__(self, input_dim, dropout: Optional[float] = 0.1): super().__init__() self.layer_norm = nn.LayerNorm(input_dim) if dropout is not None: self.dropout = nn.Dropout(dropout) def forward(self, x1, residual): x = x1 # apply dropout of needed if hasattr(self, "dropout"): x = self.dropout(x) # add and then norm x = x + residual x = self.layer_norm(x) return x class MultiHeadAttention(nn.Module): def __init__(self, input_dim, num_heads, dropout: Optional[float] = 0.1): super().__init__() self.input_dim = input_dim self.num_heads = num_heads if input_dim % num_heads != 0: raise ValueError("input_dim must be divisible by num_heads") self.head_dim = input_dim // num_heads self.dropout = dropout # initialize weights self.query_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False) self.keys_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False) self.values_w = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=False) self.ff_layer_after_concat = nn.Linear( self.num_heads * self.head_dim, input_dim, bias=False ) self.attention = ScaledDotProductAttention_MultiHead() if self.dropout is not None: self.dropout = nn.Dropout(dropout) def forward(self, queries, keys, values, mask=None): # query, keys, values shapes: [batch_size, seq_len, input_dim] batch_size, len_query, len_keys, len_values = ( queries.size(0), queries.size(1), keys.size(1), values.size(1), ) # linear transformation before attention queries = ( self.query_w(queries) .view(batch_size, len_query, self.num_heads, self.head_dim) .transpose(1, 2) ) # [batch_size, num_heads, seq_len, dim] keys = ( self.keys_w(keys) .view(batch_size, len_keys, self.num_heads, self.head_dim) .transpose(1, 2) ) # [batch_size, num_heads, seq_len, dim] values = ( self.values_w(values) .view(batch_size, len_values, self.num_heads, self.head_dim) .transpose(1, 2) ) # [batch_size, num_heads, seq_len, dim] # attention itself values, attention_weights = self.attention( queries, keys, values, mask=mask ) # values shape:[batch_size, num_heads, seq_len, dim] # concatenation out = ( values.transpose(1, 2) .contiguous() .view(batch_size, len_values, self.num_heads * self.head_dim) ) # [batch_size, seq_len, num_heads * dim = input_dim] # go through last linear layer out = self.ff_layer_after_concat(out) return out class PositionalEncoding(nn.Module): def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): super().__init__() self.dropout = nn.Dropout(p=dropout) position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) ) pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) pe = pe.permute( 1, 0, 2 ) # [seq_len, batch_size, embedding_dim] -> [batch_size, seq_len, embedding_dim] self.register_buffer("pe", pe) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: Tensor, shape [batch_size, seq_len, embedding_dim] """ x = x + self.pe[:, : x.size(1)] return self.dropout(x) class TransformerLayer(nn.Module): def __init__( self, input_dim, num_heads, dropout: Optional[float] = 0.1, positional_encoding: bool = True, ): super(TransformerLayer, self).__init__() self.positional_encoding = positional_encoding self.input_dim = input_dim self.num_heads = num_heads self.head_dim = input_dim // num_heads self.dropout = dropout # initialize layers self.self_attention = MultiHeadAttention(input_dim, num_heads, dropout=dropout) self.feed_forward = PositionWiseFeedForward( input_dim, input_dim, dropout=dropout ) self.add_norm_after_attention = Add_and_Norm(input_dim, dropout=dropout) self.add_norm_after_ff = Add_and_Norm(input_dim, dropout=dropout) # calculate positional encoding if self.positional_encoding: self.positional_encoding = PositionalEncoding(input_dim) def forward(self, key, value, query, mask=None): # key, value, and query shapes: [batch_size, seq_len, input_dim] # positional encoding if self.positional_encoding: key = self.positional_encoding(key) value = self.positional_encoding(value) query = self.positional_encoding(query) # multi-head attention residual = query x = self.self_attention(queries=query, keys=key, values=value, mask=mask) x = self.add_norm_after_attention(x, residual) # feed forward residual = x x = self.feed_forward(x) x = self.add_norm_after_ff(x, residual) return x class SelfTransformer(nn.Module): def __init__(self, input_size: int = int(1024), num_heads=1, dropout=0.1): super(SelfTransformer, self).__init__() self.att = torch.nn.MultiheadAttention( input_size, num_heads, dropout, bias=True, batch_first=True ) self.norm1 = nn.LayerNorm(input_size) self.fcl = nn.Linear(input_size, input_size) self.norm2 = nn.LayerNorm(input_size) def forward(self, video): represent, _ = self.att(video, video, video) represent_norm = self.norm1(video + represent) represent_fcl = self.fcl(represent_norm) represent = self.norm1(represent_norm + represent_fcl) return represent class SmallClassificationHead(nn.Module): """ClassificationHead""" def __init__(self, input_size=256, out_emo=6, out_sen=3): super(SmallClassificationHead, self).__init__() self.fc_emo = nn.Linear(input_size, out_emo) self.fc_sen = nn.Linear(input_size, out_sen) def forward(self, x): x_emo = self.fc_emo(x) x_sen = self.fc_sen(x) return {"emo": x_emo, "sen": x_sen} class AudioModelWT(Wav2Vec2PreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config self.wav2vec2 = Wav2Vec2Model(config) self.f_size = 1024 self.tl1 = TransformerLayer( input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True ) self.tl2 = TransformerLayer( input_dim=self.f_size, num_heads=4, dropout=0.1, positional_encoding=True ) self.fc1 = nn.Linear(1024, 1) self.dp = nn.Dropout(p=0.5) self.selu = nn.SELU() self.relu = nn.ReLU() self.cl_head = SmallClassificationHead( input_size=199, out_emo=config.out_emo, out_sen=config.out_sen ) self.init_weights() # freeze conv self.freeze_feature_encoder() def freeze_feature_encoder(self): for param in self.wav2vec2.feature_extractor.conv_layers.parameters(): param.requires_grad = False def forward(self, x, with_features=False): outputs = self.wav2vec2(x) x = self.tl1(outputs[0], outputs[0], outputs[0]) x = self.selu(x) features = self.tl2(x, x, x) x = self.selu(features) x = self.fc1(x) x = self.relu(x) x = self.dp(x) x = x.view(x.size(0), -1) if with_features: return self.cl_head(x), features else: return self.cl_head(x) class AudioFeatureExtractor: def __init__( self, checkpoint_url: str, folder_path: str, device: torch.device, sr: int = 16000, win_max_length: int = 4, with_features: bool = True, ) -> None: """ Args: sr (int, optional): Sample rate of audio. Defaults to 16000. win_max_length (int, optional): Max length of window. Defaults to 4. with_features (bool, optional): Extract features or not """ self.device = device self.sr = sr self.win_max_length = win_max_length self.with_features = with_features checkpoint_path = load_model(checkpoint_url, folder_path) model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim" model_config = AutoConfig.from_pretrained(model_name) model_config.out_emo = 7 model_config.out_sen = 3 model_config.context_length = 199 self.processor = Wav2Vec2Processor.from_pretrained(model_name) self.model = AudioModelWT.from_pretrained( pretrained_model_name_or_path=model_name, config=model_config ) checkpoint = torch.load(checkpoint_path, map_location=self.device) self.model.load_state_dict(checkpoint["model_state_dict"]) self.model.to(self.device) def preprocess_wave(self, x: torch.Tensor) -> torch.Tensor: """Extracts features for wav2vec Apply padding to max length of audio Args: x (torch.Tensor): Input data Returns: np.ndarray: Preprocessed data """ a_data = self.processor( x, sampling_rate=self.sr, return_tensors="pt", padding="max_length", max_length=self.sr * self.win_max_length, ) return a_data["input_values"][0] def __call__( self, waveform: torch.Tensor ) -> tuple[dict[torch.Tensor], torch.Tensor]: """Extracts acoustic features Apply padding to max length of audio Args: wave (torch.Tensor): wave Returns: torch.Tensor: Extracted features """ waveform = self.preprocess_wave(waveform).unsqueeze(0).to(self.device) with torch.no_grad(): if self.with_features: preds, features = self.model(waveform, with_features=self.with_features) else: preds = self.model(waveform, with_features=self.with_features) predicts = { "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(), "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(), } return ( (predicts, features.detach().cpu().squeeze()) if self.with_features else (predicts, None) ) class Tmodel(nn.Module): def __init__( self, input_size: int = int(1024), activation=nn.SELU(), feature_size1=256, feature_size2=64, num_heads=1, num_layers=2, n_emo=7, n_sent=3, ): super(Tmodel, self).__init__() self.feature_text_dynamic = nn.ModuleList( [ SelfTransformer(input_size=input_size, num_heads=num_heads) for i in range(num_layers) ] ) self.fcl = nn.Linear(input_size, feature_size1) self.activation = activation self.feature_emo = nn.Linear(feature_size1, feature_size2) self.feature_sent = nn.Linear(feature_size1, feature_size2) self.fc_emo = nn.Linear(feature_size2, n_emo) self.fc_sent = nn.Linear(feature_size2, n_sent) def get_features(self, t): for i, l in enumerate(self.feature_text_dynamic): self.features = l(t) def forward(self, t): self.get_features(t) represent = self.activation(torch.mean(t, axis=1)) represent = self.activation(self.fcl(represent)) represent_emo = self.activation(self.feature_emo(represent)) represent_sent = self.activation(self.feature_sent(represent)) prob_emo = self.fc_emo(represent_emo) prob_sent = self.fc_sent(represent_sent) return prob_emo, prob_sent class TextFeatureExtractor: def __init__( self, checkpoint_url: str, folder_path: str, device: torch.device, with_features: bool = True, ) -> None: self.device = device self.with_features = with_features model_name_bert = "julian-schelb/roberta-ner-multilingual" self.tokenizer = AutoTokenizer.from_pretrained( model_name_bert, add_prefix_space=True ) self.model_bert = AutoModel.from_pretrained(model_name_bert) checkpoint_path = load_model(checkpoint_url, folder_path) self.model = Tmodel() self.model.load_state_dict( torch.load(checkpoint_path, map_location=self.device) ) self.model.to(self.device) def preprocess_text(self, text: torch.Tensor) -> torch.Tensor: if text != "" and str(text) != "nan": inputs = self.tokenizer( text.lower(), padding="max_length", truncation="longest_first", return_tensors="pt", max_length=6, ).to(self.device) with torch.no_grad(): self.model_bert = self.model_bert.to(self.device) outputs = ( self.model_bert( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) .last_hidden_state.cpu() .detach() ) else: outputs = torch.zeros((1, 6, 1024)) return outputs def __call__(self, text: torch.Tensor) -> tuple[dict[torch.Tensor], torch.Tensor]: text_features = self.preprocess_text(text) with torch.no_grad(): if self.with_features: pred_emo, pred_sent = self.model(text_features.float().to(self.device)) temporal_features = self.model.features else: pred_emo, pred_sent = self.model(text_features.float().to(self.device)) predicts = { "emo": F.softmax(pred_emo, dim=-1).detach().cpu().squeeze(), "sen": F.softmax(pred_sent, dim=-1).detach().cpu().squeeze(), } return ( (predicts, temporal_features.detach().cpu().squeeze()) if self.with_features else (predicts, None) ) class Bottleneck(nn.Module): expansion = 4 def __init__(self, in_channels, out_channels, i_downsample=None, stride=1): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False, ) self.batch_norm1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99) self.conv2 = nn.Conv2d( out_channels, out_channels, kernel_size=3, padding="same", bias=False ) self.batch_norm2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99) self.conv3 = nn.Conv2d( out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias=False, ) self.batch_norm3 = nn.BatchNorm2d( out_channels * self.expansion, eps=0.001, momentum=0.99 ) self.i_downsample = i_downsample self.stride = stride self.relu = nn.ReLU() def forward(self, x): identity = x.clone() x = self.relu(self.batch_norm1(self.conv1(x))) x = self.relu(self.batch_norm2(self.conv2(x))) x = self.conv3(x) x = self.batch_norm3(x) # downsample if needed if self.i_downsample is not None: identity = self.i_downsample(identity) # add identity x += identity x = self.relu(x) return x class Conv2dSame(torch.nn.Conv2d): def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int: return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: ih, iw = x.size()[-2:] pad_h = self.calc_same_pad( i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0] ) pad_w = self.calc_same_pad( i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1] ) if pad_h > 0 or pad_w > 0: x = F.pad( x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] ) return F.conv2d( x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, ) class ResNet(nn.Module): def __init__(self, ResBlock, layer_list, num_classes, num_channels=3): super(ResNet, self).__init__() self.in_channels = 64 self.conv_layer_s2_same = Conv2dSame( num_channels, 64, 7, stride=2, groups=1, bias=False ) self.batch_norm1 = nn.BatchNorm2d(64, eps=0.001, momentum=0.99) self.relu = nn.ReLU() self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2) self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64, stride=1) self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2) self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2) self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc1 = nn.Linear(512 * ResBlock.expansion, 512) self.relu1 = nn.ReLU() self.fc2 = nn.Linear(512, num_classes) def extract_features_four(self, x): x = self.relu(self.batch_norm1(self.conv_layer_s2_same(x))) x = self.max_pool(x) # print(x.shape) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) return x def extract_features(self, x): x = self.extract_features_four(x) x = self.avgpool(x) x = x.reshape(x.shape[0], -1) x = self.fc1(x) return x def forward(self, x): x = self.extract_features(x) x = self.relu1(x) x = self.fc2(x) return x def _make_layer(self, ResBlock, blocks, planes, stride=1): ii_downsample = None layers = [] if stride != 1 or self.in_channels != planes * ResBlock.expansion: ii_downsample = nn.Sequential( nn.Conv2d( self.in_channels, planes * ResBlock.expansion, kernel_size=1, stride=stride, bias=False, padding=0, ), nn.BatchNorm2d(planes * ResBlock.expansion, eps=0.001, momentum=0.99), ) layers.append( ResBlock( self.in_channels, planes, i_downsample=ii_downsample, stride=stride ) ) self.in_channels = planes * ResBlock.expansion for i in range(blocks - 1): layers.append(ResBlock(self.in_channels, planes)) return nn.Sequential(*layers) def ResNet50(num_classes, channels=3): return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels) class Vmodel(nn.Module): def __init__( self, input_size=512, activation=nn.SELU(), feature_size=64, num_heads=1, num_layers=1, positional_encoding=False, n_emo=7, n_sent=3, ): super(Vmodel, self).__init__() self.feature_video_dynamic = nn.ModuleList( [ TransformerLayer( input_dim=input_size, num_heads=num_heads, positional_encoding=positional_encoding, ) for i in range(num_layers) ] ) self.fcl = nn.Linear(input_size, feature_size) self.activation = activation self.feature_emo = nn.Linear(feature_size, feature_size) self.feature_sent = nn.Linear(feature_size, feature_size) self.fc_emo = nn.Linear(feature_size, n_emo) self.fc_sent = nn.Linear(feature_size, n_sent) def forward(self, x, with_features=False): for i, l in enumerate(self.feature_video_dynamic): x = l(x, x, x) represent = self.activation(torch.mean(x, axis=1)) represent = self.activation(self.fcl(represent)) represent_emo = self.activation(self.feature_emo(represent)) represent_sent = self.activation(self.feature_sent(represent)) prob_emo = self.fc_emo(represent_emo) prob_sent = self.fc_sent(represent_sent) if with_features: return {"emo": prob_emo, "sen": prob_sent}, x else: return {"emo": prob_emo, "sen": prob_sent} class VideoModelLoader: def __init__( self, face_checkpoint_url: str, emotion_checkpoint_url: str, emo_sent_checkpoint_url: str, folder_path: str, device: torch.device, ) -> None: self.device = device # YOLO face recognition model initialization face_model_path = load_model(face_checkpoint_url, folder_path) emotion_video_model_path = load_model(emotion_checkpoint_url, folder_path) emo_sent_video_model_path = load_model(emo_sent_checkpoint_url, folder_path) self.face_model = YOLO(face_model_path) # EmoAffectet model initialization (static model) self.emo_affectnet_model = ResNet50(num_classes=7, channels=3) self.emo_affectnet_model.load_state_dict( torch.load(emotion_video_model_path, map_location=self.device) ) self.emo_affectnet_model.to(self.device).eval() # Visual emotion and sentiment recognition model (dynamic model) self.emo_sent_video_model = Vmodel() self.emo_sent_video_model.load_state_dict( torch.load(emo_sent_video_model_path, map_location=self.device) ) self.emo_sent_video_model.to(self.device).eval() def extract_zeros_features(self): zeros = torch.unsqueeze(torch.zeros((3, 224, 224)), 0).to(self.device) zeros_features = self.emo_affectnet_model.extract_features(zeros) return zeros_features.cpu().detach().numpy()[0] class VideoFeatureExtractor: def __init__( self, model_loader: VideoModelLoader, file_path: str, target_fps: int = 5, with_features: bool = True, ) -> None: self.model_loader = model_loader self.with_features = with_features # Video options self.cap = cv2.VideoCapture(file_path) self.w, self.h, self.fps, self.frame_number = ( int(self.cap.get(x)) for x in ( cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS, cv2.CAP_PROP_FRAME_COUNT, ) ) self.dur = self.frame_number / self.fps self.target_fps = target_fps self.frame_interval = int(self.fps / target_fps) # Extract zero features if no face found in frame self.zeros_features = self.model_loader.extract_zeros_features() # Dictionaries with facial features and faces self.facial_features = {} self.faces = {} def preprocess_frame(self, frame: np.ndarray, counter: int) -> None: curr_fr = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = self.model_loader.face_model.track( curr_fr, persist=True, imgsz=640, conf=0.01, iou=0.5, augment=False, device=self.model_loader.device, verbose=False, ) need_features = np.zeros(512) count_face = 0 if results[0].boxes.xyxy.cpu().tolist() != []: for i in results[0].boxes: idx_box = i.id.int().cpu().tolist()[0] if i.id else -1 box = i.xyxy.int().cpu().tolist()[0] startX, startY = max(0, box[0]), max(0, box[1]) endX, endY = min(self.w - 1, box[2]), min(self.h - 1, box[3]) face_region = curr_fr[startY:endY, startX:endX] norm_face_region = pth_processing(Image.fromarray(face_region)) with torch.no_grad(): curr_features = ( self.model_loader.emo_affectnet_model.extract_features( norm_face_region.to(self.model_loader.device) ) ) need_features += curr_features.cpu().detach().numpy()[0] count_face += 1 # face_region = cv2.resize(face_region, (224,224), interpolation = cv2.INTER_AREA) # face_region = display_frame_info(face_region, 'Frame: {}'.format(count_face), box_scale=.3) if idx_box in self.faces: self.faces[idx_box].update({counter: face_region}) else: self.faces[idx_box] = {counter: face_region} need_features /= count_face self.facial_features[counter] = need_features else: if counter - 1 in self.facial_features: self.facial_features[counter] = self.facial_features[counter - 1] else: self.facial_features[counter] = self.zeros_features def preprocess_video(self) -> None: counter = 0 while True: ret, frame = self.cap.read() if not ret: break if counter % self.frame_interval == 0: self.preprocess_frame(frame, counter) counter += 1 def __call__( self, window: dict, win_max_length: int, sr: int = 16000 ) -> tuple[dict[torch.Tensor], torch.Tensor]: curr_idx_frames = get_idx_frames_in_windows( list(self.facial_features.keys()), window, self.fps, sr ) video_features = np.array(list(self.facial_features.values())) curr_features = video_features[curr_idx_frames, :] if len(curr_features) < self.target_fps * win_max_length: diff = self.target_fps * win_max_length - len(curr_features) curr_features = np.concatenate( [curr_features, [curr_features[-1]] * diff], axis=0 ) curr_features = ( torch.FloatTensor(curr_features).unsqueeze(0).to(self.model_loader.device) ) with torch.no_grad(): if self.with_features: preds, features = self.model_loader.emo_sent_video_model( curr_features, with_features=self.with_features ) else: preds = self.model_loader.emo_sent_video_model( curr_features, with_features=self.with_features ) predicts = { "emo": F.softmax(preds["emo"], dim=-1).detach().cpu().squeeze(), "sen": F.softmax(preds["sen"], dim=-1).detach().cpu().squeeze(), } return ( (predicts, features.detach().cpu().squeeze()) if self.with_features else (predicts, None) )