from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import PreTrainedModel, PretrainedConfig
from transformers import CONFIG_MAPPING, MODEL_MAPPING
import torch
import torch.nn.functional as F
import torch.nn as nn

# Encoder used when a config does not name one explicitly.
_DEFAULT_ENCODER = "jinaai/jina-embeddings-v3"


class JinaJudgeConfig(PretrainedConfig):
    """Configuration for :class:`JinaJudge`.

    Args:
        n_classes: Number of logits produced by the classification head.
        hidden_dim: Stored on the config but not read by ``JinaJudge``
            (see the NOTE in ``__init__``).
        num_decoder_layers: Number of ``nn.TransformerDecoderLayer`` blocks.
        nhead: Number of attention heads in each decoder layer.
        dropout_prob: Dropout probability inside the decoder layers.
        encoder_name_or_path: Hub id or local path from which the encoder's
            tokenizer and architecture config are loaded. Defaults to the
            encoder this model was originally built on, so configs saved
            without this field keep their previous behavior.
        **kwargs: Forwarded to :class:`PretrainedConfig`.
    """

    model_type = "jina-judge"

    def __init__(
        self,
        n_classes=3,
        hidden_dim=512,
        num_decoder_layers=4,
        nhead=8,
        dropout_prob=0.1,
        encoder_name_or_path=_DEFAULT_ENCODER,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_classes = n_classes
        # NOTE(review): hidden_dim is never read by JinaJudge; the decoder's
        # feed-forward width is the encoder hidden size. Wiring it in would
        # change parameter shapes and break existing checkpoints, so it is
        # kept only for config compatibility.
        self.hidden_dim = hidden_dim
        self.num_decoder_layers = num_decoder_layers
        self.nhead = nhead
        self.dropout_prob = dropout_prob
        self.encoder_name_or_path = encoder_name_or_path


class JinaJudge(PreTrainedModel):
    """Classify text prompts with an encoder + single-query decoder + linear head.

    Prompts are tokenized and encoded by a jina-embeddings encoder. One
    learned query vector then cross-attends to the encoder states through
    an ``nn.TransformerDecoder``, and the resulting vector is mapped to
    ``config.n_classes`` logits.
    """

    config_class = JinaJudgeConfig

    def __init__(self, config: JinaJudgeConfig):
        super().__init__(config)
        encoder_name = getattr(config, "encoder_name_or_path", _DEFAULT_ENCODER)

        self.tokenizer = AutoTokenizer.from_pretrained(encoder_name, trust_remote_code=True)
        encoder_config = AutoConfig.from_pretrained(encoder_name, trust_remote_code=True)
        # from_config builds the encoder architecture (in bfloat16) without
        # loading pretrained encoder weights. Presumably those are restored
        # from the JinaJudge checkpoint via from_pretrained -- TODO confirm.
        self.encoder = AutoModel.from_config(
            encoder_config, trust_remote_code=True, torch_dtype=torch.bfloat16
        )
        # Flag defined by the encoder's remote code. It presumably makes the
        # main (non-LoRA) encoder weights trainable -- confirm against the
        # jina-embeddings-v3 modeling code.
        self.encoder.lora_main_params_trainable = True

        hidden_size = self.encoder.config.hidden_size
        # batch_first is left at its default (False), so the decoder works on
        # sequence-first tensors: (seq_len, batch, hidden_size).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=config.nhead,
            dim_feedforward=hidden_size,
            dropout=config.dropout_prob,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=config.num_decoder_layers
        )
        # One learned query token of shape (tgt_len=1, batch=1, hidden_size).
        # It is broadcast across the batch in forward().
        self.decoder_input_embedding = nn.Parameter(torch.randn(1, 1, hidden_size))
        # Kept as a one-layer Sequential so that state-dict keys
        # ("classification_head.0.*") match existing checkpoints.
        self.classification_head = nn.Sequential(
            nn.Linear(hidden_size, config.n_classes)
        )

    def forward(self, prompts):
        """Return classification logits for ``prompts``.

        Args:
            prompts: Text passed directly to the tokenizer (a string or a
                list of strings). The tokenizer pads and truncates it.

        Returns:
            Unnormalized logits of shape ``(batch, n_classes)``.
        """
        inputs = self.tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        encoder_outputs = self.encoder(**inputs)
        # Cast to the decoder's dtype instead of hard-coding float32. This
        # keeps memory and decoder weights consistent even when the model
        # has been cast or loaded in another dtype.
        encoder_hidden_states = encoder_outputs.last_hidden_state.to(
            self.decoder_input_embedding.dtype
        )
        # True marks padding positions the decoder must not attend to. The
        # mask is already on self.device because `inputs` was moved there.
        encoder_padding_mask = inputs["attention_mask"] == 0

        batch_size = encoder_hidden_states.size(0)
        # (1, 1, H) -> (1, batch, H): the same learned query for every example.
        decoder_input = self.decoder_input_embedding.expand(1, batch_size, -1).to(self.device)
        decoder_output = self.decoder(
            tgt=decoder_input,
            # (batch, seq, H) -> (seq, batch, H) for the sequence-first decoder.
            memory=encoder_hidden_states.transpose(0, 1),
            memory_key_padding_mask=encoder_padding_mask,
        ).squeeze(0)  # (1, batch, H) -> (batch, H)
        logits = self.classification_head(decoder_output)
        return logits


# Make the "jina-judge" model type resolvable through AutoConfig / AutoModel.
AutoConfig.register("jina-judge", JinaJudgeConfig)
AutoModel.register(JinaJudgeConfig, JinaJudge)