Description

DialoGPT is a dialogue-response variant of GPT-2, the Generative Pre-trained Transformer language model released by OpenAI; DialoGPT itself was developed by Microsoft. It is a neural language model trained on large amounts of conversational text (comment threads from Reddit) to generate human-like responses.

DialoGPT uses the transformer architecture, a type of neural network designed for processing sequential data such as language. During training, the model is exposed to a large corpus of text and learns to predict the next token (a word or word piece) in a sequence given the previous tokens.
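The training objective described above can be illustrated with a toy sketch in plain Python. No model is involved: the token IDs and probabilities are made-up placeholders. Each prefix of a sequence is paired with the token that follows it, and training minimizes the average negative log-likelihood the model assigns to those next tokens.

```python
import math


def next_token_pairs(token_ids):
    """Split a token sequence into (prefix, next_token) training pairs.

    A causal language model sees token_ids[:i] and is trained to predict token_ids[i].
    """
    return [(token_ids[:i], token_ids[i]) for i in range(1, len(token_ids))]


def average_nll(target_probs):
    """Mean negative log-likelihood, given the probability the model assigned
    to each correct next token (the quantity minimized during training)."""
    return -sum(math.log(p) for p in target_probs) / len(target_probs)


# Hypothetical token IDs for a short sentence.
ids = [15, 42, 7, 99]
for prefix, target in next_token_pairs(ids):
    print(prefix, "->", target)
# [15] -> 42
# [15, 42] -> 7
# [15, 42, 7] -> 99

# Hypothetical probabilities a model might assign to two correct next tokens.
print(round(average_nll([0.5, 0.25]), 4))  # 1.0397
```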

In the dialogue setting, DialoGPT is trained to predict the response in a conversation given its context. The context is the dialogue history: one or more previous turns concatenated into a single sequence, with each turn terminated by the end-of-sequence token.
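As a sketch of how such a context is laid out, the helper below joins turns with `<|endoftext|>`, the end-of-sequence token DialoGPT inherits from GPT-2. It works on strings for readability; the usage snippet in this card builds the same layout on token IDs by appending `tokenizer.eos_token` to each message and concatenating with `torch.cat`. The function name and example turns are illustrative only.

```python
EOS = "<|endoftext|>"  # GPT-2 / DialoGPT end-of-sequence token


def build_context(turns):
    """Join dialogue turns (oldest first) into one string, ending each turn with EOS.

    The model is asked to continue this string; the text it generates up to
    the next EOS is taken as the response.
    """
    return "".join(turn + EOS for turn in turns)


# Hypothetical conversation history.
history = ["Hi Morty!", "Oh geez, Rick, hi.", "Ready for an adventure?"]
print(build_context(history))
```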

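At inference time, the model takes the current conversation context as input and generates a response one token at a time. Each token is chosen from the model's predicted distribution over the vocabulary, either greedily or by sampling (the usage example in this card samples, with temperature, top-k, and top-p filtering).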
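A minimal pure-Python sketch of sampling with the three settings used in this card's usage snippet (`temperature`, `top_k`, `top_p`). It shows the idea only and is not the Transformers implementation; the filtering order assumed here is temperature, then top-k, then top-p on the renormalized top-k distribution. The dictionary of logits stands in for the model's output and is hypothetical.

```python
import math
import random


def sample_next_token(logits, temperature=0.8, top_k=100, top_p=0.7, rng=random):
    """Sample one token from raw logits (a dict: token -> logit).

    Applies temperature scaling, top-k filtering, then top-p (nucleus)
    filtering, and samples from what remains.
    """
    # Temperature: divide logits before the softmax; < 1 sharpens, > 1 flattens.
    scaled = {t: l / temperature for t, l in logits.items()}
    # Softmax, subtracting the max logit for numerical stability.
    m = max(scaled.values())
    exps = {t: math.exp(l - m) for t, l in scaled.items()}
    total = sum(exps.values())
    probs = sorted(((e / total, t) for t, e in exps.items()), reverse=True)

    # Top-k: keep the k most probable tokens, then renormalize.
    probs = probs[:top_k]
    s = sum(p for p, _ in probs)
    probs = [(p / s, t) for p, t in probs]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for p, t in probs:
        kept.append((p, t))
        cumulative += p
        if cumulative >= top_p:
            break

    # Sample from the kept tokens in proportion to their probabilities.
    r = rng.random() * sum(p for p, _ in kept)
    for p, t in kept:
        r -= p
        if r <= 0:
            return t
    return kept[-1][1]  # guard against floating-point leftovers


# Hypothetical logits for four candidate tokens.
logits = {"Rick": 2.0, "geez": 1.5, "portal": 0.3, "banana": -1.0}
print(sample_next_token(logits, rng=random.Random(0)))
```

Lowering `top_p` or `top_k` shrinks the candidate set toward greedy decoding; lowering `temperature` concentrates probability on the top tokens before filtering.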

Overall, DialoGPT is a flexible basis for generating human-like text in conversation, suitable for applications such as chatbots, conversational agents, and virtual assistants.

Parameters

The model was fine-tuned for 40 epochs with the following parameters:

        per_gpu_train_batch_size: int = 2
        per_gpu_eval_batch_size: int = 2
        gradient_accumulation_steps: int = 1
        learning_rate: float = 5e-5
        weight_decay: float = 0.0
        adam_epsilon: float = 1e-8
        max_grad_norm: float = 1.0
        num_train_epochs: int = 40
        max_steps: int = -1
        warmup_steps: int = 0
        logging_steps: int = 1000
        save_steps: int = 3500
        save_total_limit = None
        eval_all_checkpoints: bool = False
        no_cuda: bool = False
        overwrite_output_dir: bool = True
        overwrite_cache: bool = True
        should_continue: bool = False
        seed: int = 42
        local_rank: int = -1
        fp16: bool = False
        fp16_opt_level: str = 'O1'
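A rough sketch of what these settings imply for the length of training. It assumes a single GPU, the convention of the Transformers example scripts that `max_steps = -1` lets `num_train_epochs` determine the run length, and a hypothetical dataset size; this card does not state how many training examples were used. The function name is illustrative.

```python
import math


def total_optimization_steps(num_examples, per_gpu_batch_size=2, n_gpu=1,
                             gradient_accumulation_steps=1, num_train_epochs=40):
    """Number of optimizer updates over the whole run.

    Assumes each batch holds per_gpu_batch_size * n_gpu examples and that the
    optimizer steps once every gradient_accumulation_steps batches.
    """
    batch_size = per_gpu_batch_size * max(1, n_gpu)
    batches_per_epoch = math.ceil(num_examples / batch_size)
    steps_per_epoch = batches_per_epoch // gradient_accumulation_steps
    return steps_per_epoch * num_train_epochs


steps = total_optimization_steps(1000)  # hypothetical dataset of 1,000 examples
print(steps)          # 20000
print(steps // 3500)  # checkpoints written at save_steps=3500: 5
print(steps // 1000)  # log events at logging_steps=1000: 20
```

With batch size 2 and no gradient accumulation, the effective batch size is 2, so each epoch takes one optimizer step per pair of examples.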

Usage

This is the small version of DialoGPT, fine-tuned on Morty's lines of dialogue (Morty is a character in the cartoon Rick and Morty).

A simple snippet showing how to run inference with this model:

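import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the fine-tuned tokenizer and model from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained('s3nh/DialoGPT-small-morty')
model = AutoModelForCausalLM.from_pretrained('s3nh/DialoGPT-small-morty')

# Chat for four turns, carrying the conversation history between turns.
for step in range(4):
    # Encode the user's message, terminated by the end-of-sequence token.
    new_user_input_ids = tokenizer.encode(input(">> User: ") + tokenizer.eos_token, return_tensors='pt')

    # Append the new message to the history (there is no history on the first turn).
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # Generate a reply; max_length counts the history as well as the new tokens.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,  # never repeat any 3-gram
        do_sample=True,          # sample instead of greedy decoding
        top_k=100,               # keep the 100 most likely tokens...
        top_p=0.7,               # ...then the smallest set covering 70% of the probability
        temperature=0.8,         # sharpen the distribution slightly
    )

    # Decode only the newly generated tokens (everything after the input).
    print("MortyBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))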
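In the snippet above, `max_length=200` bounds the whole sequence, history included, so after several turns little or no room is left for a reply. One option, not part of the original snippet, is to keep only the most recent whole turns within a token budget before calling `generate`. This sketch assumes the history is kept as a list of per-turn token-ID lists (the original stores a single concatenated tensor); `trim_history` is a hypothetical helper.

```python
def trim_history(turn_token_ids, max_tokens):
    """Keep the most recent whole turns whose total length fits in max_tokens.

    turn_token_ids: list of token-ID lists, oldest first, each ending with EOS.
    Returns a flat list of token IDs to use as the model's context. Returns an
    empty list if even the newest turn alone exceeds the budget.
    """
    kept, total = [], 0
    for turn in reversed(turn_token_ids):  # walk from newest to oldest
        if total + len(turn) > max_tokens:
            break
        kept.append(turn)
        total += len(turn)
    # Restore chronological order and flatten.
    return [tok for turn in reversed(kept) for tok in turn]


turns = [[1, 2, 0], [3, 0], [4, 5, 6, 0]]  # hypothetical IDs; 0 stands in for EOS
print(trim_history(turns, max_tokens=6))  # [3, 0, 4, 5, 6, 0]
```

To pass the result to `model.generate`, wrap it as a batch of one with `torch.tensor([ids])`, and choose `max_tokens` below `max_length` so there is room for the reply.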


Format: Safetensors
Model size: 137M params
Tensor types: F32, U8