"""Gradio demo that classifies a document image as an invoice or a receipt.

Uses the ``fedihch/InvoiceReceiptClassifier`` sequence-classification
checkpoint, fed by a LayoutLMv2 processor built from the
``microsoft/layoutlmv2-base-uncased`` tokenizer.
"""

import functools

import gradio as gr

# Output index of the classifier -> label returned to the UI.
ID2LABEL = {0: "invoice", 1: "receipt"}


@functools.lru_cache(maxsize=1)
def _load_model_and_processor():
    """Build the classifier and its LayoutLMv2 processor once and cache them.

    Loading reads model weights and tokenizer files, so it is done on the
    first call only; later calls return the same ``(model, processor)`` pair.
    ``transformers`` is imported lazily so importing this module does not
    pay that import cost up front.
    """
    from transformers import (
        AutoModelForSequenceClassification,
        LayoutLMv2FeatureExtractor,
        LayoutLMv2Processor,
        LayoutLMv2Tokenizer,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        "fedihch/InvoiceReceiptClassifier"
    )
    feature_extractor = LayoutLMv2FeatureExtractor()
    tokenizer = LayoutLMv2Tokenizer.from_pretrained("microsoft/layoutlmv2-base-uncased")
    processor = LayoutLMv2Processor(feature_extractor, tokenizer)
    return model, processor


def classify(input_img):
    """Classify a single document image as ``"invoice"`` or ``"receipt"``.

    Args:
        input_img: the image delivered by the ``gr.Image`` input component
            (presumably a NumPy array, Gradio's default image type — confirm
            if the component's ``type`` is changed).

    Returns:
        The predicted label string from ``ID2LABEL``.
    """
    import torch

    model, processor = _load_model_and_processor()

    encoded_inputs = processor(input_img, return_tensors="pt")
    # Move every encoded tensor onto the same device as the model.
    inputs = {key: value.to(model.device) for key, value in encoded_inputs.items()}

    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # .item() assumes a batch of one image, which is what the UI sends.
    predicted_class_idx = logits.argmax(-1).item()
    return ID2LABEL[predicted_class_idx]


demo = gr.Interface(
    fn=classify,
    # NOTE(review): `shape=` is a Gradio 3.x argument (removed in Gradio 4);
    # pin gradio<4 or migrate this component before upgrading.
    inputs=gr.Image(shape=(200, 200)),
    outputs="text",
    allow_flagging="manual",
)
demo.launch(share=True)