nielsr HF staff commited on
Commit
32a4a70
1 Parent(s): 89bb5fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -8
app.py CHANGED
@@ -8,14 +8,7 @@ torch.hub.download_url_to_file('https://storage.googleapis.com/perceiver_io/dalm
8
  feature_extractor = PerceiverFeatureExtractor()
9
  model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
10
 
11
- # define custom pipeline as Perceiver expects "inputs" rather than "pixel_values"
12
- class CustomPipeline(ImageClassificationPipeline):
13
- def _forward(self, model_inputs):
14
- inputs = model_inputs["pixel_values"]
15
- model_outputs = self.model(inputs=inputs)
16
- return model_outputs
17
-
18
- image_pipe = CustomPipeline(model=model, feature_extractor=feature_extractor)
19
 
20
  def classify_image(image):
21
  results = image_pipe(image)
 
8
  feature_extractor = PerceiverFeatureExtractor()
9
  model = PerceiverForImageClassificationConvProcessing.from_pretrained("deepmind/vision-perceiver-conv")
10
 
11
+ image_pipe = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
 
 
 
 
 
 
 
12
 
13
  def classify_image(image):
14
  results = image_pipe(image)