File size: 14,277 Bytes
b381617 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip install --upgrade tensorflow\n",
"!pip install --upgrade keras\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "U-cy2aDRljAL",
"outputId": "e9efe1eb-0135-432f-9d46-3c39a8013f5b"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.12.0)\n",
"Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0)\n",
"Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3)\n",
"Requirement already satisfied: flatbuffers>=2.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (23.3.3)\n",
"Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.4.0)\n",
"Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n",
"Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.54.0)\n",
"Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.8.0)\n",
"Requirement already satisfied: jax>=0.3.15 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.4.10)\n",
"Requirement already satisfied: keras<2.13,>=2.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.12.0)\n",
"Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (16.0.0)\n",
"Requirement already satisfied: numpy<1.24,>=1.22 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.22.4)\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.3.0)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (23.1)\n",
"Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.20.3)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (67.7.2)\n",
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n",
"Requirement already satisfied: tensorboard<2.13,>=2.12 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.12.2)\n",
"Requirement already satisfied: tensorflow-estimator<2.13,>=2.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.12.0)\n",
"Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.3.0)\n",
"Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.5.0)\n",
"Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.14.1)\n",
"Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.32.0)\n",
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.40.0)\n",
"Requirement already satisfied: ml-dtypes>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.3.15->tensorflow) (0.1.0)\n",
"Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax>=0.3.15->tensorflow) (1.10.1)\n",
"Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (2.17.3)\n",
"Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (1.0.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (3.4.3)\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (2.27.1)\n",
"Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (0.7.0)\n",
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (1.8.1)\n",
"Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.13,>=2.12->tensorflow) (2.3.0)\n",
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow) (5.3.0)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow) (0.3.0)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.10/dist-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow) (4.9)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow) (1.3.1)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow) (1.26.15)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow) (2022.12.7)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow) (2.0.12)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow) (3.4)\n",
"Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.13,>=2.12->tensorflow) (2.1.2)\n",
"Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow) (0.5.0)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow) (3.2.2)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: keras in /usr/local/lib/python3.10/dist-packages (2.12.0)\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 371
},
"id": "b6wDH2UglAGa",
"outputId": "4049c353-40c4-4784-f064-e32e1678a309"
},
"outputs": [
{
"output_type": "error",
"ename": "AttributeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-80a54cfd4831>\u001b[0m in \u001b[0;36m<cell line: 50>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;31m# Load the model weights from a checkpoint.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 50\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVisionTransformer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_classes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 51\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"checkpoints/model.ckpt\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-5-80a54cfd4831>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, num_classes)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;31m# The transformer encoder consists of a stack of self-attention layers.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m self.transformer_encoder = layers.TransformerEncoder(\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mnum_layers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m12\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: module 'keras.api._v2.keras.layers' has no attribute 'TransformerEncoder'"
]
}
],
"source": [
"import tensorflow as tf\n",
"import tensorflow.keras as keras\n",
"from tensorflow.keras import layers\n",
"\n",
"# Define the model architecture.\n",
"class VisionTransformer(keras.Model):\n",
"\n",
" def __init__(self, num_classes):\n",
" super(VisionTransformer, self).__init__()\n",
"\n",
" # The embedding layer converts each image patch into a vector representation.\n",
" self.embedding = layers.Embedding(\n",
" input_dim=256,\n",
" output_dim=512,\n",
" input_shape=(7, 7, 3),\n",
" trainable=True,\n",
" )\n",
"\n",
" # The transformer encoder consists of a stack of self-attention layers.\n",
" self.transformer_encoder = layers.TransformerEncoder(\n",
" num_heads=8,\n",
" num_layers=12,\n",
" dropout=0.1,\n",
" )\n",
"\n",
" # The classification layer outputs the class probabilities for each image.\n",
" self.classification = layers.Dense(num_classes, activation=\"softmax\")\n",
"\n",
" def call(self, inputs):\n",
" # Extract the image patches from the input image.\n",
" patches = tf.image.extract_patches(\n",
" inputs=inputs,\n",
" size=(7, 7, 3),\n",
" strides=(2, 2, 1),\n",
" padding=\"SAME\",\n",
" )\n",
"\n",
" # Convert the image patches into vector representations.\n",
" embedded_patches = self.embedding(patches)\n",
"\n",
" # Encode the image patches using the transformer encoder.\n",
" encoded_patches = self.transformer_encoder(embedded_patches)\n",
"\n",
" # Classify the image using the classification layer.\n",
" predictions = self.classification(encoded_patches)\n",
"\n",
" return predictions\n",
"\n",
"# Load the model weights from a checkpoint.\n",
"model = VisionTransformer(num_classes=100)\n",
"model.load_weights(\"checkpoints/model.ckpt\")\n",
"\n",
"# Create a video capture object.\n",
"cap = cv2.VideoCapture(0)\n",
"\n",
"# Start a loop to capture frames from the camera and classify them.\n",
"while True:\n",
"\n",
" # Capture a frame from the camera.\n",
" ret, frame = cap.read()\n",
"\n",
" # Convert the frame to a NumPy array.\n",
" frame = np.array(frame)\n",
"\n",
" # Classify the frame.\n",
" predictions = model.predict(frame)\n",
"\n",
" # Display the classification results.\n",
" cv2.putText(frame, \"Prediction: {}\".format(predictions[0]), (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)\n",
"\n",
" # Display the frame.\n",
" cv2.imshow(\"Frame\", frame)\n",
"\n",
" # Wait for a key press.\n",
" key = cv2.waitKey(1) & 0xFF\n",
"\n",
" # If the key `q` is pressed, break out of the loop.\n",
" if key == ord(\"q\"):\n",
" break\n",
"\n",
"# Release the video capture object.\n",
"cap.release()\n",
"\n",
"# Close all open windows.\n",
"cv2.destroyAllWindows()"
]
}
]
} |