import cv2
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from tensorflow.keras.optimizers import Adam

from .constants import LEARNING_RATE


def get_model():
    """
    Download the model from the Hugging Face Hub and compile it.
    """
    model = from_pretrained_keras("keras-io/video-vision-transformer")
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss="sparse_categorical_crossentropy",
        # metrics=[
        #     keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        #     keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        # ],
    )
    return model


model = get_model()

labels = [
    'liver', 'kidney-right', 'kidney-left', 'femur-right', 'femur-left',
    'bladder', 'heart', 'lung-right', 'lung-left', 'spleen', 'pancreas',
]


def predict_label(path):
    """Load the video at `path`, run the model, and return the predicted organ label."""
    frames = load_video(path)
    dataloader = prepare_dataloader(frames)
    prediction = model.predict(dataloader)[0]
    label = np.argmax(prediction, axis=0)
    return labels[label]


def load_video(path):
    """
    Load the video at `path` and return its frames as a NumPy array.

    The frames are converted to grayscale because that is the format
    expected by the model.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(frame)
    finally:
        cap.release()
    return np.array(frames)


def prepare_dataloader(video):
    """Wrap a single video in a batched, preprocessed tf.data pipeline."""
    # Add a leading batch dimension so the dataset yields exactly one video.
    video = tf.expand_dims(video, axis=0)
    # The label is a dummy value; only the model's prediction is used.
    dataset = tf.data.Dataset.from_tensor_slices((video, np.array([0])))
    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(1)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frame tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label
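

# A minimal usage sketch, not part of the original app: it only illustrates
# how predict_label can be invoked locally. The video path below is a
# hypothetical placeholder and should be replaced with a real file.
if __name__ == "__main__":
    sample_path = "sample_scan.avi"  # hypothetical example path
    print(predict_label(sample_path))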