waveydaveygravy committed on
Commit
b629419
•
1 Parent(s): 1dda0a5

Upload vid2cn2vid.ipynb

Files changed (1)
  1. vid2cn2vid.ipynb +726 -0
vid2cn2vid.ipynb ADDED
@@ -0,0 +1,726 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "asNLOn0uIC5o"
+ },
+ "outputs": [],
+ "source": [
+ "### based on https://github.com/patrickvonplaten/controlnet_aux\n",
+ "### which is derived from https://github.com/lllyasviel/ControlNet/tree/main/annotator and connected to the 🤗 Hub.\n",
+ "\n",
+ "# All credit & copyright goes to https://github.com/lllyasviel .\n",
+ "# Some of the models are large; comment them out to save space if not needed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qbM01EucvW58"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install controlnet-aux==0.0.7\n",
+ "!pip install -U openmim\n",
+ "!pip install cog\n",
+ "!pip install mediapipe\n",
+ "!mim install mmengine\n",
+ "!mim install \"mmcv>=2.0.1\"\n",
+ "!mim install \"mmdet>=3.1.0\"\n",
+ "!mim install \"mmpose>=1.1.0\"\n",
+ "!pip install moviepy\n",
+ "# (argparse ships with the Python standard library, so it does not need to be pip-installed)\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "# Create the directory /content/test\n",
+ "os.makedirs(\"/content/test\", exist_ok=True)\n",
+ "\n",
+ "# Create the directory /content/frames\n",
+ "os.makedirs(\"/content/frames\", exist_ok=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Upload your input video; the next cell expects it at /content/a.mp4\n",
+ "from google.colab import files\n",
+ "uploaded = files.upload()"
+ ],
+ "metadata": {
+ "id": "fy-P7QkwCMBd"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title break video down into frames\n",
+ "import cv2\n",
+ "\n",
+ "# Open the video file\n",
+ "cap = cv2.VideoCapture('/content/a.mp4')\n",
+ "\n",
+ "i = 0\n",
+ "while cap.isOpened():\n",
+ "    ret, frame = cap.read()\n",
+ "\n",
+ "    if not ret:\n",
+ "        break\n",
+ "\n",
+ "    # Save each frame of the video\n",
+ "    cv2.imwrite('/content/frames/frame_' + str(i) + '.jpg', frame)\n",
+ "\n",
+ "    i += 1\n",
+ "\n",
+ "cap.release()\n",
+ "cv2.destroyAllWindows()"
+ ],
+ "metadata": {
+ "id": "Kw0hIeYnvjLV"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "### Comment out the processors you don't want to use; also comment out the ones with large models if you want to save space\n",
+ "### based on https://github.com/patrickvonplaten/controlnet_aux\n",
+ "### which is derived from https://github.com/lllyasviel/ControlNet/tree/main/annotator and connected to the 🤗 Hub.\n",
+ "# All credit & copyright goes to https://github.com/lllyasviel .\n",
+ "# Some of the models are large; comment them out to save space if not needed.\n",
+ "\n",
+ "import torch\n",
+ "import os\n",
+ "import shutil\n",
+ "from PIL import Image\n",
+ "from tqdm import tqdm\n",
+ "from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,\n",
+ "                            LeresDetector, LineartAnimeDetector,\n",
+ "                            LineartDetector, MediapipeFaceDetector,\n",
+ "                            MidasDetector, MLSDdetector, NormalBaeDetector,\n",
+ "                            OpenposeDetector, PidiNetDetector, SamDetector,\n",
+ "                            ZoeDetector, DWposeDetector)\n",
+ "\n",
+ "# Create the directory /content/test\n",
+ "os.makedirs(\"/content/test\", exist_ok=True)\n",
+ "\n",
+ "INPUT_DIR = \"/content/frames\"  # replace with your input directory\n",
+ "OUTPUT_DIR = \"/content/test\"  # replace with your output directory\n",
+ "\n",
+ "# Check if CUDA is available and set the device accordingly\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "\n",
+ "def output(filename, img):\n",
+ "    img.save(os.path.join(OUTPUT_DIR, filename))\n",
+ "\n",
+ "def process_image(processor, img):\n",
+ "    return processor(img)\n",
+ "\n",
+ "def load_images():\n",
+ "    if os.path.exists(OUTPUT_DIR):\n",
+ "        shutil.rmtree(OUTPUT_DIR)\n",
+ "    os.mkdir(OUTPUT_DIR)\n",
+ "    images = []\n",
+ "    filenames = []\n",
+ "    for filename in os.listdir(INPUT_DIR):\n",
+ "        if filename.endswith(\".png\") or filename.endswith(\".jpg\"):\n",
+ "            img_path = os.path.join(INPUT_DIR, filename)\n",
+ "            img = Image.open(img_path).convert(\"RGB\").resize((512, 512))\n",
+ "            images.append(img)\n",
+ "            filenames.append(filename)\n",
+ "    return images, filenames\n",
+ "\n",
+ "def process_images(processor):\n",
+ "    images, filenames = load_images()\n",
+ "    for img, filename in tqdm(zip(images, filenames), total=len(images), desc=\"Processing images\"):\n",
+ "        output_img = process_image(processor, img)\n",
+ "        output(filename, output_img)\n",
+ "\n",
+ "# Initialize the detectors\n",
+ "\n",
+ "canny = CannyDetector()\n",
+ "hed = HEDdetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "shuffle = ContentShuffleDetector()\n",
+ "leres = LeresDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "lineart_anime = LineartAnimeDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "lineart = LineartDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "mediapipe_face = MediapipeFaceDetector()\n",
+ "midas = MidasDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "mlsd = MLSDdetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "normal_bae = NormalBaeDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "openpose = OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "pidi_net = PidiNetDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "sam = SamDetector.from_pretrained(\"ybelkada/segment-anything\", subfolder=\"checkpoints\")\n",
+ "#zoe = ZoeDetector.from_pretrained(\"lllyasviel/Annotators\")\n",
+ "#dwpose = DWposeDetector()\n",
+ "\n",
+ "\n",
+ "# Run the image processing\n",
+ "# Uncomment the line for the detector you want to use\n",
+ "#process_images(canny)\n",
+ "#process_images(hed)\n",
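+ "\n",
+ "# For example, to run the Openpose detector over every frame in INPUT_DIR\n",
+ "# (any of the detectors initialized above can be passed the same way):\n",
+ "#process_images(openpose)"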
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "46d65432-5661-4377-ab34-64e5767f6e91",
+ "id": "pXgCvJvi45mo"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/timm/models/_factory.py:117: UserWarning: Mapping deprecated model name vit_base_resnet50_384 to current vit_base_r50_s16_384.orig_in21k_ft_in1k.\n",
+ "  model = create_fn(\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Loading base model ()...Done.\n",
+ "Removing last two layers (global_pool & classifier).\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Processing images: 100%|██████████| 7/7 [00:14<00:00, 2.02s/it]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Command-line version (may need extra work).\n",
+ "# Run the test.py cell below first: it writes the script to /content/test.py via %%writefile.\n",
+ "!python /content/test.py --processor hed --use_cuda --output_dir /content/test/"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "lsJnu9BiJbId",
+ "outputId": "c68d113f-27bc-4bf0-9c04-625b1fce6aa5"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "python3: can't open file '/content/test.py': [Errno 2] No such file or directory\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%%writefile /content/test.py\n",
+ "### COMMAND LINE VERSION test.py (this cell writes the script to /content/test.py)\n",
+ "# based on https://github.com/patrickvonplaten/controlnet_aux\n",
+ "### which is derived from https://github.com/lllyasviel/ControlNet/tree/main/annotator and connected to the 🤗 Hub.\n",
+ "\n",
+ "# All credit & copyright goes to https://github.com/lllyasviel .\n",
+ "# Some of the models are large; comment them out to save space if not needed.\n",
+ "import torch\n",
+ "import argparse\n",
+ "import os\n",
+ "import shutil\n",
+ "from PIL import Image\n",
+ "from tqdm import tqdm\n",
+ "from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,\n",
+ "                            LeresDetector, LineartAnimeDetector,\n",
+ "                            LineartDetector, MediapipeFaceDetector,\n",
+ "                            MidasDetector, MLSDdetector, NormalBaeDetector,\n",
+ "                            OpenposeDetector, PidiNetDetector, SamDetector,\n",
+ "                            ZoeDetector, DWposeDetector)\n",
+ "\n",
+ "# Create the directory /content/test\n",
+ "os.makedirs(\"/content/test\", exist_ok=True)\n",
+ "\n",
+ "INPUT_DIR = \"/content/frames\"  # replace with your input directory\n",
+ "\n",
+ "def output(filename, img):\n",
+ "    img.save(os.path.join(OUTPUT_DIR, filename))\n",
+ "\n",
+ "def process_image(processor, img):\n",
+ "    return processor(img)\n",
+ "\n",
+ "def load_images():\n",
+ "    if os.path.exists(OUTPUT_DIR):\n",
+ "        shutil.rmtree(OUTPUT_DIR)\n",
+ "    os.mkdir(OUTPUT_DIR)\n",
+ "    images = []\n",
+ "    filenames = []\n",
+ "    for filename in os.listdir(INPUT_DIR):\n",
+ "        if filename.endswith(\".png\") or filename.endswith(\".jpg\"):\n",
+ "            img_path = os.path.join(INPUT_DIR, filename)\n",
+ "            img = Image.open(img_path).convert(\"RGB\").resize((512, 512))\n",
+ "            images.append(img)\n",
+ "            filenames.append(filename)\n",
+ "    return images, filenames\n",
+ "\n",
+ "def process_images(processor):\n",
+ "    images, filenames = load_images()\n",
+ "    for img, filename in tqdm(zip(images, filenames), total=len(images), desc=\"Processing images\"):\n",
+ "        output_img = process_image(processor, img)\n",
+ "        output(filename, output_img)\n",
+ "\n",
+ "\n",
+ "# Initialize the argument parser\n",
+ "parser = argparse.ArgumentParser(description='Choose a processor to run.')\n",
+ "parser.add_argument('--processor', type=str, help='The name of the processor to run.')\n",
+ "parser.add_argument('--use_cuda', action='store_true', help='Use CUDA if available.')\n",
+ "parser.add_argument('--output_dir', type=str, default='/content/test', help='The directory to save the output.')\n",
+ "# Parse the arguments\n",
+ "args = parser.parse_args()\n",
+ "\n",
+ "OUTPUT_DIR = args.output_dir  # honor the --output_dir flag\n",
+ "\n",
+ "# Check if CUDA is available and set the device accordingly\n",
+ "device = torch.device(\"cuda\" if args.use_cuda and torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "\n",
+ "# Initialize the detectors\n",
+ "detectors = {\n",
+ "    'canny': CannyDetector(),\n",
+ "    'hed': HEDdetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'shuffle': ContentShuffleDetector(),\n",
+ "    'leres': LeresDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'lineart_anime': LineartAnimeDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'lineart': LineartDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'mediapipe_face': MediapipeFaceDetector(),\n",
+ "    'midas': MidasDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'mlsd': MLSDdetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'normal_bae': NormalBaeDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'openpose': OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'pidi_net': PidiNetDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    'sam': SamDetector.from_pretrained(\"ybelkada/segment-anything\", subfolder=\"checkpoints\"),\n",
+ "    # 'zoe': ZoeDetector.from_pretrained(\"lllyasviel/Annotators\"),\n",
+ "    # 'dwpose': DWposeDetector(),\n",
+ "}\n",
+ "\n",
+ "# Run the chosen processor\n",
+ "if args.processor in detectors:\n",
+ "    detector = detectors[args.processor]\n",
+ "    process_images(detector)\n",
+ "else:\n",
+ "    print(f\"Unknown processor: {args.processor}\")\n"
+ ],
+ "metadata": {
+ "id": "8YYwMuMpJoKB"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title reassemble processed frames into a video (best to keep fps the same as the input video)\n",
+ "# The batch-processing cell above keeps the original names, so frames land in /content/test as frame_<n>.jpg\n",
+ "!ffmpeg -framerate 25 -start_number 0 -i /content/test/frame_%d.jpg -c:v libx264 -vf \"fps=25,format=yuv420p\" /content/testpose.mp4\n",
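+ "\n",
+ "# Alternative sketch: if the frame numbers aren't contiguous, ffmpeg's glob pattern\n",
+ "# avoids the %d sequence requirement (assumes the processed frames are .jpg files):\n",
+ "#!ffmpeg -framerate 25 -pattern_type glob -i '/content/test/*.jpg' -c:v libx264 -pix_fmt yuv420p /content/testpose.mp4"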
+ ],
+ "metadata": {
+ "id": "8kUk-kFPwzmq"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#display video\n",
+ "from IPython.display import HTML\n",
+ "from base64 import b64encode\n",
+ "\n",
+ "# Open the video file and read its contents\n",
+ "mp4 = open('/content/testpose.mp4', 'rb').read()\n",
+ "\n",
+ "# Encode the video data as a base64 string\n",
+ "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+ "\n",
+ "# Display the video using an HTML video element\n",
+ "HTML(f\"\"\"\n",
+ "<video width=400 controls>\n",
+ "  <source src=\"{data_url}\" type=\"video/mp4\">\n",
+ "</video>\n",
+ "\"\"\")"
+ ],
+ "metadata": {
+ "id": "6AKmRPK3J7GO"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Zip up results for download; replace the placeholder with your files/folders\n",
+ "!zip -r nameof.zip <location of files and folder>",
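+ "\n",
+ "# Hypothetical example, assuming the processed frames were written to /content/test:\n",
+ "#!zip -r processed_frames.zip /content/test"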
+ ],
+ "metadata": {
+ "id": "Oax1BHwYTZog"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FaF3RdKdaFa8"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Login to Hugging Face 🤗\n",
+ "\n",
+ "#@markdown You need to accept the model license before downloading or using the Stable Diffusion weights. Please visit the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license, and tick the checkbox if you agree. You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work.\n",
+ "# https://huggingface.co/settings/tokens\n",
+ "!mkdir -p ~/.huggingface\n",
+ "HUGGINGFACE_TOKEN = \"\" #@param {type:\"string\"}\n",
+ "!echo -n \"{HUGGINGFACE_TOKEN}\" > ~/.huggingface/token",
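+ "\n",
+ "# Alternative sketch using the huggingface_hub client instead of writing the token file by hand:\n",
+ "#from huggingface_hub import login\n",
+ "#login(token=HUGGINGFACE_TOKEN)"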
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aEJZoFQ2YHIb"
+ },
+ "outputs": [],
+ "source": [
+ "#@title upload to Hugging Face\n",
+ "from huggingface_hub import HfApi\n",
+ "api = HfApi()\n",
+ "api.upload_file(\n",
+ "    path_or_fileobj=\"\",  # local path of the file to upload, e.g. the zip created above\n",
+ "    path_in_repo=\"name.zip\",\n",
+ "    repo_id=\"\",  # e.g. your-username/your-dataset\n",
+ "    repo_type=\"dataset\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title working FAST batch-processing code template (WIP: doesn't save with the original filenames)\n",
+ "\n",
+ "import torch\n",
+ "import os\n",
+ "from typing import List\n",
+ "from cog import BasePredictor, Input, Path\n",
+ "from PIL import Image\n",
+ "from io import BytesIO\n",
+ "import time\n",
+ "from tqdm import tqdm\n",
+ "from controlnet_aux.processor import Processor\n",
+ "from controlnet_aux import (\n",
+ "    HEDdetector,\n",
+ "    MidasDetector,\n",
+ "    MLSDdetector,\n",
+ "    OpenposeDetector,\n",
+ "    PidiNetDetector,\n",
+ "    NormalBaeDetector,\n",
+ "    LineartDetector,\n",
+ "    LineartAnimeDetector,\n",
+ "    CannyDetector,\n",
+ "    ContentShuffleDetector,\n",
+ "    ZoeDetector,\n",
+ "    MediapipeFaceDetector,\n",
+ "    SamDetector,\n",
+ "    LeresDetector,\n",
+ "    DWposeDetector,\n",
+ ")\n",
+ "\n",
+ "image_dir = '/content/frames'\n",
+ "\n",
+ "class Predictor(BasePredictor):\n",
+ "    def setup(self) -> None:\n",
+ "        \"\"\"Load the models into memory to make running multiple predictions efficient\"\"\"\n",
+ "\n",
+ "        self.annotators = {\n",
+ "            \"canny\": CannyDetector(),\n",
+ "            \"content\": ContentShuffleDetector(),\n",
+ "            \"face_detector\": MediapipeFaceDetector(),\n",
+ "            \"hed\": self.initialize_detector(HEDdetector),\n",
+ "            \"midas\": self.initialize_detector(MidasDetector),\n",
+ "            \"mlsd\": self.initialize_detector(MLSDdetector),\n",
+ "            \"open_pose\": self.initialize_detector(OpenposeDetector),\n",
+ "            \"pidi\": self.initialize_detector(PidiNetDetector),\n",
+ "            \"normal_bae\": self.initialize_detector(NormalBaeDetector),\n",
+ "            \"lineart\": self.initialize_detector(LineartDetector),\n",
+ "            \"lineart_anime\": self.initialize_detector(LineartAnimeDetector),\n",
+ "            # \"zoe\": self.initialize_detector(ZoeDetector),\n",
+ "            # \"mobile_sam\": self.initialize_detector(\n",
+ "            #     SamDetector,\n",
+ "            #     model_name=\"dhkim2810/MobileSAM\",\n",
+ "            #     model_type=\"vit_t\",\n",
+ "            #     filename=\"mobile_sam.pt\",\n",
+ "            # ),\n",
+ "            \"leres\": self.initialize_detector(LeresDetector),\n",
+ "        }\n",
+ "\n",
+ "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "    def initialize_detector(\n",
+ "        self, detector_class, model_name=\"lllyasviel/Annotators\", **kwargs\n",
+ "    ):\n",
+ "        return detector_class.from_pretrained(\n",
+ "            model_name,\n",
+ "            cache_dir=\"model_cache\",\n",
+ "            **kwargs,\n",
+ "        )\n",
+ "\n",
+ "    def predict(\n",
+ "        self,\n",
+ "        image_dir: str = Input(\n",
+ "            default=\"/content/frames\",\n",
+ "            description=\"Directory containing the images to be processed\"\n",
+ "        ),\n",
+ "        canny: bool = Input(default=True, description=\"Run canny edge detection\"),\n",
+ "        content: bool = Input(default=True, description=\"Run content shuffle detection\"),\n",
+ "        face_detector: bool = Input(default=True, description=\"Run face detection\"),\n",
+ "        hed: bool = Input(default=True, description=\"Run HED detection\"),\n",
+ "        midas: bool = Input(default=True, description=\"Run Midas detection\"),\n",
+ "        mlsd: bool = Input(default=True, description=\"Run MLSD detection\"),\n",
+ "        open_pose: bool = Input(default=True, description=\"Run Openpose detection\"),\n",
+ "        pidi: bool = Input(default=True, description=\"Run PidiNet detection\"),\n",
+ "        normal_bae: bool = Input(default=True, description=\"Run NormalBae detection\"),\n",
+ "        lineart: bool = Input(default=True, description=\"Run Lineart detection\"),\n",
+ "        lineart_anime: bool = Input(default=True, description=\"Run LineartAnime detection\"),\n",
+ "        leres: bool = Input(default=True, description=\"Run Leres detection\"),\n",
+ "    ) -> List[Path]:\n",
+ "        # Load all images into memory\n",
+ "        start_time = time.time()  # start time for overall processing\n",
+ "        filenames = sorted(os.listdir(image_dir))\n",
+ "        images = [Image.open(os.path.join(image_dir, name)).convert(\"RGB\").resize((512, 512)) for name in filenames]\n",
+ "\n",
+ "        paths = []\n",
+ "        annotator_inputs = {\n",
+ "            \"canny\": canny,\n",
+ "            \"content\": content,\n",
+ "            \"face_detector\": face_detector,\n",
+ "            \"hed\": hed,\n",
+ "            \"midas\": midas,\n",
+ "            \"mlsd\": mlsd,\n",
+ "            \"open_pose\": open_pose,\n",
+ "            \"pidi\": pidi,\n",
+ "            \"normal_bae\": normal_bae,\n",
+ "            \"lineart\": lineart,\n",
+ "            \"lineart_anime\": lineart_anime,\n",
+ "            \"leres\": leres,\n",
+ "        }\n",
+ "        os.makedirs(\"/content/test2\", exist_ok=True)\n",
+ "        for annotator, run_annotator in annotator_inputs.items():\n",
+ "            if run_annotator:\n",
+ "                for image, image_name in zip(images, filenames):\n",
+ "                    processed_image = self.annotators[annotator](image)\n",
+ "                    processed_path = f'/content/test2/{annotator}_{image_name}'\n",
+ "                    processed_image.save(processed_path)\n",
+ "                    paths.append(Path(processed_path))\n",
+ "\n",
+ "        return paths\n",
+ "\n",
+ "# Standalone version of the same idea, run directly in the notebook:\n",
+ "# load the frames, process them all with one detector, and save the results.\n",
+ "images = []\n",
+ "for name in sorted(os.listdir(image_dir)):\n",
+ "    images.append(Image.open(os.path.join(image_dir, name)))\n",
+ "\n",
+ "# One Processor instance is enough; reuse it for every frame\n",
+ "processor = Processor(\"lineart_anime\")\n",
+ "\n",
+ "# Process all images with a progress bar and save each result\n",
+ "for i, image in enumerate(tqdm(images, desc=\"Processing images\")):\n",
+ "    processed = processor(image, to_pil=True)\n",
+ "    processed.save(f'/content/test/lineart_anime_{i}.png')\n"
+ ],
+ "metadata": {
+ "id": "ajRzOZtiDrGP",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 213,
+ "referenced_widgets": [
+ "b7207b0dd06849beb14d8c0cdaebcaa0",
+ "ce42c14100d342f1a1b929fead2c1d60",
+ "8d60c1de06464ac49b383e558e33c8f7",
+ "83779c3a8afc4fb4a4b71bb3b4dae8be",
+ "ea55fcf91d4346c5820079305f5c4752",
+ "2e93ac3132a74f9ea031f94c222230fb",
+ "bb87d0010b71413db38634d4f3d7dc9a",
+ "d97ff4fe72954aa5891a44e48b7eea35",
+ "70fa493960104cf4bc032470ff7f3dcf",
+ "2302182b276c4b3b82d23b35001a5893",
+ "623b93d070da4478abb4039f722af9ec"
+ ]
+ },
+ "outputId": "d102947c-450c-4dcf-96b1-d8fedf5da525"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\r  0%|          | 0/7 [00:00<?, ?it/s]"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "netG.pth:   0%|          | 0.00/218M [00:00<?, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "b7207b0dd06849beb14d8c0cdaebcaa0"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "100%|██████████| 7/7 [00:06<00:00, 1.14it/s]\n",
+ "Processing images: 100%|██████████| 7/7 [00:07<00:00, 1.07s/it]\n",
+ "Processing images: 100%|██████████| 7/7 [00:06<00:00, 1.01it/s]\n",
+ "Processing images: 100%|██████████| 7/7 [00:06<00:00, 1.04it/s]\n",
+ "Processing images: 100%|██████████| 7/7 [00:07<00:00, 1.05s/it]\n",
+ "Processing images: 100%|██████████| 7/7 [00:06<00:00, 1.11it/s]\n",
+ "Processing images: 100%|██████████| 7/7 [00:07<00:00, 1.06s/it]\n",
+ "Processing images: 100%|██████████| 7/7 [00:06<00:00, 1.10it/s]\n"
+ ]
+ }
+ ]
+ }
+ ]
+ }