import logging

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Anomaly detection helps trace NaN/inf gradients while debugging, but slows training;
# disable it for production runs.
torch.autograd.set_detect_anomaly(True)


class JudgeXLConfig(PretrainedConfig):
    model_type = "judge-xl"

    def __init__(self, vocab_size=50276, hidden_size=768, max_len=256, n_layer=12,
                 n_head=12, ff_expansion_factor=4, rnn_units=768, num_labels=5,
                 dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_len = max_len
        self.n_layer = n_layer
        self.n_head = n_head
        self.ff_expansion_factor = ff_expansion_factor
        self.rnn_units = rnn_units
        self.num_labels = num_labels
        self.dropout = dropout
        self.is_decoder = True


class CustomEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        logger.debug("vocab_size: %s, hidden_size: %s", vocab_size, hidden_size)
        assert isinstance(vocab_size, int) and isinstance(hidden_size, int), \
            f"Expected integers, but got vocab_size={type(vocab_size)} and hidden_size={type(hidden_size)}"
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, inputs):
        return self.embedding(inputs)


class PositionalEncoding(nn.Module):
    def __init__(self, n_embd, max_len=5000):
        super().__init__()
        self.n_embd = n_embd
        self.max_len = max_len
        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape (1, max_len, n_embd) so it broadcasts over batch-first inputs.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, n_embd)
        return x + self.pe[:, :x.size(1), :]


class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense1 = nn.Linear(config.hidden_size, config.hidden_size * config.ff_expansion_factor)
        self.dense2 = nn.Linear(config.hidden_size * config.ff_expansion_factor, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = torch.nn.functional.gelu(self.dense1(x))
        x = self.dropout(x)
        return self.dense2(x)


class TransformerXLBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # batch_first=True keeps the (batch, seq_len, hidden) layout used throughout the model.
        self.attn = nn.MultiheadAttention(config.hidden_size, config.n_head,
                                          dropout=config.dropout, batch_first=True)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        out1 = self.ln1(x + attn_out)
        ff_out = self.ff(out1)
        return self.ln2(out1 + ff_out)


class JudgeXL(PreTrainedModel):
    config_class = JudgeXLConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = CustomEmbedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_len)
        self.transformer_blocks = nn.ModuleList(
            [TransformerXLBlock(config) for _ in range(config.n_layer)]
        )
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.rnn = nn.LSTM(config.hidden_size, config.rnn_units, num_layers=2,
                           dropout=config.dropout, bidirectional=True, batch_first=True)
        # The bidirectional LSTM doubles the feature dimension, so the LM head
        # projects from rnn_units * 2 to vocabulary logits.
        self.lm_head = nn.Linear(config.rnn_units * 2, config.vocab_size)
        self.post_init()

    def forward(self, x, mask=None):
        x = self.token_embedding(x)      # (batch, seq_len, hidden_size)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        x, _ = self.rnn(x)               # (batch, seq_len, rnn_units * 2)
        return self.lm_head(x)           # (batch, seq_len, vocab_size)

    def init_weights(self):
        """Initialize the custom layers with PreTrainedModel's default weight initialization."""
        super().init_weights()

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        # The model does not currently return past_key_values; the cache branch is a
        # placeholder for future key/value caching.
        if past is None:
            return {"input_ids": input_ids}
        return {"input_ids": input_ids[:, -1:], "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)

    def generate(self, prompt, max_len=100):
        """Greedy decoding from a text prompt.

        Expects a tokenizer to be attached beforehand, e.g. ``model.tokenizer = tokenizer``.
        """
        self.eval()
        input_ids = self.tokenizer(prompt, return_tensors='pt').input_ids
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_len):
                outputs = self.forward(generated)
                next_token_logits = outputs[:, -1, :]  # logits for the last position only
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                generated = torch.cat((generated, next_token_id), dim=1)
                if next_token_id.item() == self.tokenizer.sep_token_id:
                    break
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)


config = JudgeXLConfig()
model = JudgeXL(config)

# Register the custom config and model so they can be loaded through the Auto* classes.
JudgeXLConfig.register_for_auto_class(AutoConfig)
JudgeXL.register_for_auto_class(AutoModelForCausalLM)

model.push_to_hub("Wonder-Griffin/judge-xl-model")
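

# A minimal smoke test (not part of the original script): it builds a fresh JudgeXL from
# the default config and pushes a dummy batch through forward() to confirm the expected
# (batch, seq_len, vocab_size) logit shape. The batch size, sequence length, and the use
# of random token ids are illustrative assumptions.
if __name__ == "__main__":
    demo_config = JudgeXLConfig()
    demo_model = JudgeXL(demo_config)
    dummy_ids = torch.randint(0, demo_config.vocab_size, (2, 16))  # (batch=2, seq_len=16)
    with torch.no_grad():
        logits = demo_model(dummy_ids)
    logger.info("logits shape: %s", tuple(logits.shape))  # expected: (2, 16, 50276)

    # To sample text, attach a tokenizer first. The choice of tokenizer is an assumption;
    # any tokenizer whose vocabulary fits config.vocab_size works, e.g.:
    # from transformers import AutoTokenizer
    # demo_model.tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # print(demo_model.generate("The court finds that", max_len=20))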