Spaces:
Runtime error
Runtime error
from models import PoemTextModel | |
from inference import predict_poems_from_text | |
from utils import get_poem_embeddings | |
import config as CFG | |
import json | |
import torch | |
import gradio as gr | |
def greet_user(name):
    """Build the welcome message shown to *name*.

    The result is the fixed prefix ``"Hello "``, then *name*, then the
    welcome suffix, combined with ``+`` exactly as written below.
    """
    greeting = "Hello " + name
    return greeting + " Welcome to Gradio!π"
def load_poem_embeddings(path, device):
    """Load precomputed poem embeddings and their beyt texts from a JSON file.

    The file must contain a JSON list of records, each with an
    ``'embeddings'`` list of numbers and a ``'beyt'`` string. All
    ``'embeddings'`` lists must have the same length; otherwise tensor
    construction fails.

    Args:
        path: Path of the JSON file, read as UTF-8.
        device: Torch device the embedding matrix is created on.

    Returns:
        ``(embeddings, poems)``: a float32 tensor with one row per record,
        and the list of beyt strings in the same row order.

    Raises:
        ValueError: if the file holds no records. An empty index cannot
            answer any query, so fail at startup rather than per request.
    """
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    if not records:
        raise ValueError(f"{path!s} contains no poem records")
    embeddings = torch.tensor(
        [record['embeddings'] for record in records],
        dtype=torch.float32,
        device=device,
    )
    poems = [record['beyt'] for record in records]
    return embeddings, poems


def main():
    """Load the model and poem index, then serve a Gradio poem-search UI."""
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model.eval()

    poem_embeddings, poems = load_poem_embeddings('poem_embeddings.json', CFG.device)
    print(poem_embeddings.shape)

    def gradio_make_predictions(text):
        """Return the beyts predict_poems_from_text finds for *text* (n=10), one per line."""
        # Pure inference: no autograd graph is needed. torch.no_grad is
        # thread-local, so it is entered here, in the callback itself, rather
        # than around app.launch().
        with torch.no_grad():
            beyts = predict_poems_from_text(model, poem_embeddings, text, poems, n=10)
        return "\n".join(beyts)

    # NOTE(review): presumably consumed by the project's inference code; it is
    # set before launch() so it is in effect for every request — confirm.
    CFG.batch_size = 512

    text_input = gr.Textbox(label="Enter the text to find poem beyts for")
    output = gr.Textbox()
    app = gr.Interface(fn=gradio_make_predictions, inputs=text_input, outputs=output)
    app.launch()


if __name__ == "__main__":
    main()