"""Gradio demo that generates an email subject line with a PEFT-tuned GPT-2.

Downloads the fine-tuned adapter weights from a Weights & Biases artifact,
wraps the model in a text-in / text-out Gradio interface, and launches it.
"""
import os
import re

import gradio
import torch
import wandb
from peft import PeftModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel


def clean_text(text):
    """Normalize raw email text: lowercase, strip non-word chars, collapse spaces."""
    text = text.lower()
    text = re.sub(r'\W', ' ', text)           # replace special characters with spaces
    text = re.sub(r'\s+', ' ', text).strip()  # collapse runs of whitespace
    return text


# FIXME(security): never commit API keys to source control. This key is now
# public and must be rotated; supply it via the environment or a secrets store.
os.environ["WANDB_API_KEY"] = "d2ad0a7285379c0808ca816971d965fc242d0b5e"
wandb.login()

run = wandb.init(project="Email_subject_gen", job_type="model_loading")
artifact = run.use_artifact('Email_subject_gen/final_model:v0')
artifact_dir = artifact.download()

# Base checkpoint for tokenizer and model; the W&B artifact holds only the
# PEFT adapter weights layered on top of it.
MODEL_KEY = 'olm/olm-gpt2-dec-2022'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_KEY)
tokenizer.add_special_tokens({'pad_token': '{PAD}'})

model = GPT2LMHeadModel.from_pretrained(MODEL_KEY)
model.resize_token_embeddings(len(tokenizer))  # account for the added pad token
# NOTE(review): GPT-2 configs name these `resid_pdrop` / `attn_pdrop`, so the
# two assignments below set unread attributes; dropout is also inactive at
# inference. Kept for parity with the original script — confirm and remove.
model.config.dropout = 0.1
model.config.attention_dropout = 0.1
model = PeftModel.from_pretrained(model, artifact_dir)


def generateSubject(email):
    """Generate a subject line (at most 10 new tokens) for an email body.

    The email text is normalized with `clean_text`, left-padded and fed to the
    model; only the newly generated continuation is decoded and returned.
    """
    # BUG FIX: the original called clean_text(email) twice (first result was
    # discarded) and wrapped the prompt in no-op empty-string concatenations.
    prompt = clean_text(email)
    tokenizer.padding_side = 'left'  # left-pad so generation continues the prompt
    batch = tokenizer([prompt], padding=True, truncation=True,
                      return_tensors='pt').to(model.device)
    output_ids = model.generate(
        **batch,
        max_new_tokens=10,
        pad_token_id=tokenizer.pad_token_id)
    # BUG FIX: the original did `seq.split('')[1]`, which raises
    # "ValueError: empty separator" on every call (the real separator token
    # was presumably lost). Slice off the prompt tokens instead so only the
    # generated subject text is decoded.
    prompt_len = batch['input_ids'].shape[1]
    subjects = tokenizer.batch_decode(output_ids[:, prompt_len:],
                                      skip_special_tokens=True)
    tokenizer.padding_side = 'right'  # restore the default for other callers
    print(subjects)
    return subjects[0]


def predict(name):
    # NOTE(review): unused demo handler (not wired to the interface); kept so
    # any external caller of `predict` keeps working.
    return "Hello " + name + "!!"


iface = gradio.Interface(fn=generateSubject, inputs="text", outputs="text")
iface.launch()