Process images in a directory on the same machine where the server is running." +
+ f"
Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ f"
Add inpaint batch mask directory to enable inpaint batch processing."
+ f"{hidden}
"
+ )
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+
+ def copy_image(img):
+ if isinstance(img, dict) and 'image' in img:
+ return img['image']
+
+ return img
+
+ for button, name, elem in copy_image_buttons:
+ button.click(
+ fn=copy_image,
+ inputs=[elem],
+ outputs=[copy_image_destinations[name]],
+ )
+ button.click(
+ fn=lambda: None,
+ _js="switch_to_"+name.replace(" ", "_"),
+ inputs=[],
+ outputs=[],
+ )
+
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "cfg":
+ with FormGroup():
+ with FormRow():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="img2img_checkboxes", variant="compact"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "override_settings":
+ with FormRow(elem_id="img2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('img2img', row)
+
+ elif category == "scripts":
+ with FormGroup(elem_id="img2img_script_container"):
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+
+ elif category == "inpaint":
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
+
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
+
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
+
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
+
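+ # tabs are indexed in the order listed below: 0 img2img, 1 sketch, 2 inpaint, 3 inpaint sketch,
+ # 4 inpaint upload, 5 batch; the inpaint controls are shown for tabs 2-4, and the mask
+ # transparency slider only for the inpaint sketch tab (3)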
+ def select_img2img_tab(tab):
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
+
+ for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
+ elem.select(
+ fn=lambda tab=i: select_img2img_tab(tab),
+ inputs=[],
+ outputs=[inpaint_controls, mask_alpha],
+ )
+
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
+
+ connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
+ connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+
+ img2img_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ img2img_prompt_img
+ ],
+ outputs=[
+ img2img_prompt,
+ img2img_prompt_img
+ ]
+ )
+
+ img2img_args = dict(
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
+ _js="submit_img2img",
+ inputs=[
+ dummy_component,
+ dummy_component,
+ img2img_prompt,
+ img2img_negative_prompt,
+ img2img_prompt_styles,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
+ init_img_inpaint,
+ init_mask_inpaint,
+ steps,
+ sampler_index,
+ mask_blur,
+ mask_alpha,
+ inpainting_fill,
+ restore_faces,
+ tiling,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ image_cfg_scale,
+ denoising_strength,
+ seed,
+ subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
+ height,
+ width,
+ resize_mode,
+ inpaint_full_res,
+ inpaint_full_res_padding,
+ inpainting_mask_invert,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ img2img_batch_inpaint_mask_dir,
+ override_settings,
+ ] + custom_inputs,
+ outputs=[
+ img2img_gallery,
+ generation_info,
+ html_info,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ interrogate_args = dict(
+ _js="get_img2img_tab_index",
+ inputs=[
+ dummy_component,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ init_img_inpaint,
+ ],
+ outputs=[img2img_prompt, dummy_component],
+ )
+
+ img2img_prompt.submit(**img2img_args)
+ submit.click(**img2img_args)
+ res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
+
+ img2img_interrogate.click(
+ fn=lambda *args: process_interrogate(interrogate, *args),
+ **interrogate_args,
+ )
+
+ img2img_deepbooru.click(
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
+ **interrogate_args,
+ )
+
+ prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
+ style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
+ style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
+
+ for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
+ button.click(
+ fn=add_style,
+ _js="ask_for_style_name",
+ # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
+ # the same number of parameters, but we only know the style-name after the JavaScript prompt
+ inputs=[dummy_component, prompt, negative_prompt],
+ outputs=[txt2img_prompt_styles, img2img_prompt_styles],
+ )
+
+ for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
+ button.click(
+ fn=apply_styles,
+ _js=js_func,
+ inputs=[prompt, negative_prompt, styles],
+ outputs=[prompt, negative_prompt, styles],
+ )
+
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+
+ img2img_paste_fields = [
+ (img2img_prompt, "Prompt"),
+ (img2img_negative_prompt, "Negative prompt"),
+ (steps, "Steps"),
+ (sampler_index, "Sampler"),
+ (restore_faces, "Face restoration"),
+ (cfg_scale, "CFG scale"),
+ (image_cfg_scale, "Image CFG scale"),
+ (seed, "Seed"),
+ (width, "Size-1"),
+ (height, "Size-2"),
+ (batch_size, "Batch size"),
+ (subseed, "Variation seed"),
+ (subseed_strength, "Variation seed strength"),
+ (seed_resize_from_w, "Seed resize from-1"),
+ (seed_resize_from_h, "Seed resize from-2"),
+ (denoising_strength, "Denoising strength"),
+ (mask_blur, "Mask blur"),
+ *modules.scripts.scripts_img2img.infotext_fields
+ ]
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+ ))
+
+ modules.scripts.scripts_current = None
+
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
+ ui_postprocessing.create_ui()
+
+ with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel'):
+ image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
+
+ with gr.Column(variant='panel'):
+ html = gr.HTML()
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
+ html2 = gr.HTML()
+ with gr.Row():
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
+
+ for tabname, button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
+ ))
+
+ image.change(
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
+ inputs=[image],
+ outputs=[html, generation_info, html2],
+ )
+
+ def update_interp_description(value):
+ interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
+ with FormRow():
+ train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
+
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+
+ with FormRow():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+ with FormRow():
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
+
+ with FormRow():
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
+
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
+
+ with FormRow():
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
+ use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
+
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
+
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
+
+ with gr.Row():
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
+
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+ script_callbacks.ui_train_tabs_callback(params)
+
+ with gr.Column(elem_id='ti_gallery_container'):
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
+
+ create_embedding.click(
+ fn=modules.textual_inversion.ui.create_embedding,
+ inputs=[
+ new_embedding_name,
+ initialization_text,
+ nvpt,
+ overwrite_old_embedding,
+ ],
+ outputs=[
+ train_embedding_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ create_hypernetwork.click(
+ fn=modules.hypernetworks.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
+ new_hypernetwork_layer_structure,
+ new_hypernetwork_activation_func,
+ new_hypernetwork_initialization_option,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout,
+ new_hypernetwork_dropout_structure
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ run_preprocess.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ dummy_component,
+ process_src,
+ process_dst,
+ process_width,
+ process_height,
+ preprocess_txt_action,
+ process_flip,
+ process_split,
+ process_caption,
+ process_caption_deepbooru,
+ process_split_threshold,
+ process_overlap_ratio,
+ process_focal_crop,
+ process_focal_crop_face_weight,
+ process_focal_crop_entropy_weight,
+ process_focal_crop_edges_weight,
+ process_focal_crop_debug,
+ process_multicrop,
+ process_multicrop_mindim,
+ process_multicrop_maxdim,
+ process_multicrop_minarea,
+ process_multicrop_maxarea,
+ process_multicrop_objective,
+ process_multicrop_threshold,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ],
+ )
+
+ train_embedding.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ dummy_component,
+ train_embedding_name,
+ embedding_learn_rate,
+ batch_size,
+ gradient_step,
+ dataset_directory,
+ log_directory,
+ training_width,
+ training_height,
+ varsize,
+ steps,
+ clip_grad_mode,
+ clip_grad_value,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
+ use_weight,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ save_image_with_stored_embedding,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ dummy_component,
+ train_hypernetwork_name,
+ hypernetwork_learn_rate,
+ batch_size,
+ gradient_step,
+ dataset_directory,
+ log_directory,
+ training_width,
+ training_height,
+ varsize,
+ steps,
+ clip_grad_mode,
+ clip_grad_value,
+ shuffle_tags,
+ tag_drop_out,
+ latent_sampling_method,
+ use_weight,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ interrupt_training.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ interrupt_preprocessing.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
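+ # builds the gradio widget for a single option: the component class comes from the option's
+ # metadata if set, otherwise it is inferred from the type of the default value, and a refresh
+ # button is attached when the option defines a refresh callback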
+ def create_setting_component(key, is_quicksettings=False):
+ def fun():
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+ info = opts.data_labels[key]
+ t = type(info.default)
+
+ args = info.component_args() if callable(info.component_args) else info.component_args
+
+ if info.component is not None:
+ comp = info.component
+ elif t == str:
+ comp = gr.Textbox
+ elif t == int:
+ comp = gr.Number
+ elif t == bool:
+ comp = gr.Checkbox
+ else:
+ raise Exception(f'bad options item type: {str(t)} for key {key}')
+
+ elem_id = "setting_"+key
+
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ else:
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ else:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+
+ return res
+
+ components = []
+ component_dict = {}
+ shared.settings_components = component_dict
+
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
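+ # applies every value submitted from the settings tab: validate the types first, then set the
+ # changed options and persist them to the config file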
+ def run_settings(*args):
+ changed = []
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ if comp == dummy_component:
+ continue
+
+ if opts.set(key, value):
+ changed.append(key)
+
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
+
+ def run_settings_single(value, key):
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
+
+ opts.save(shared.config_filename)
+
+ return get_value_for_setting(key), opts.dumpjson()
+
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
+ with gr.Row():
+ with gr.Column(scale=6):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
+
+ result = gr.HTML(elem_id="settings_result")
+
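+ # options named in opts.quicksettings are shown in the quicksettings row at the top of the UI
+ # instead of the settings tab; remember their order so the row can be sorted accordingly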
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
+
+ quicksettings_list = []
+
+ previous_section = None
+ current_tab = None
+ current_row = None
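+ # settings sections become tabs: whenever the section changes while iterating over the options,
+ # the previous TabItem/Column pair is closed via __exit__ and a new pair is opened via __enter__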
+ with gr.Tabs(elem_id="settings"):
+ for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
+
+ if previous_section != item.section and not section_must_be_skipped:
+ elem_id, text = item.section
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ gr.Group()
+ current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
+
+ previous_section = item.section
+
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
+ quicksettings_list.append((i, k, item))
+ components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ component_dict[k] = component
+ components.append(component)
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ with gr.TabItem("Actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+
+ with gr.TabItem("Licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
+
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+ request_notifications.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='function(){}'
+ )
+
+ download_localization.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='download_localization'
+ )
+
+ def reload_scripts():
+ modules.scripts.reload_script_body_only()
+ reload_javascript() # need to refresh the html page
+
+ reload_script_bodies.click(
+ fn=reload_scripts,
+ inputs=[],
+ outputs=[]
+ )
+
+ def request_restart():
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+ restart_gradio.click(
+ fn=request_restart,
+ _js='restart_reload',
+ inputs=[],
+ outputs=[],
+ )
+
+ interfaces = [
+ (txt2img_interface, "txt2img", "txt2img"),
+ (img2img_interface, "img2img", "img2img"),
+ (extras_interface, "Extras", "extras"),
+ (pnginfo_interface, "PNG Info", "pnginfo"),
+ (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (train_interface, "Train", "ti"),
+ ]
+
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
+
+ if os.path.exists(os.path.join(data_path, "user.css")):
+ with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
+ css += file.read() + "\n"
+
+ if not cmd_opts.no_progressbar_hiding:
+ css += css_hide_progressbar
+
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+
+ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
+ with gr.Row(elem_id="quicksettings", variant="compact"):
+ for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ component = create_setting_component(k, is_quicksettings=True)
+ component_dict[k] = component
+
+ parameters_copypaste.connect_paste_params_buttons()
+
+ with gr.Tabs(elem_id="tabs") as tabs:
+ for interface, label, ifid in interfaces:
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
+ interface.render()
+
+ if os.path.exists(os.path.join(script_path, "notification.mp3")):
+ audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+
+ footer = shared.html("footer.html")
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
+
+ text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+ settings_submit.click(
+ fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
+ inputs=components,
+ outputs=[text_settings, result],
+ )
+
+ for i, k, item in quicksettings_list:
+ component = component_dict[k]
+
+ component.change(
+ fn=lambda value, k=k: run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, text_settings],
+ )
+
+ text_settings.change(
+ fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
+ inputs=[],
+ outputs=[image_cfg_scale],
+ )
+
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_model_checkpoint'], dummy_component],
+ outputs=[component_dict['sd_model_checkpoint'], text_settings],
+ )
+
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [get_value_for_setting(key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys],
+ )
+
+ def modelmerger(*args):
+ try:
+ results = modules.extras.run_modelmerger(*args)
+ except Exception as e:
+ print("Error loading/saving model file:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
+ return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+ return results
+
+ modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
+ modelmerger_merge.click(
+ fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
+ inputs=[
+ dummy_component,
+ primary_model_name,
+ secondary_model_name,
+ tertiary_model_name,
+ interp_method,
+ interp_amount,
+ save_as_half,
+ custom_name,
+ checkpoint_format,
+ config_source,
+ bake_in_vae,
+ discard_weights,
+ ],
+ outputs=[
+ primary_model_name,
+ secondary_model_name,
+ tertiary_model_name,
+ component_dict['sd_model_checkpoint'],
+ modelmerger_result,
+ ]
+ )
+
+ ui_config_file = cmd_opts.ui_config_file
+ ui_settings = {}
+ settings_count = len(ui_settings)
+ error_loading = False
+
+ try:
+ if os.path.exists(ui_config_file):
+ with open(ui_config_file, "r", encoding="utf8") as file:
+ ui_settings = json.load(file)
+ except Exception:
+ error_loading = True
+ print("Error loading settings:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
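+ # applies values from ui-config.json to newly built components; keys missing from the file are
+ # added with the component's current default so the file can be (re)written below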
+ def loadsave(path, x):
+ def apply_field(obj, field, condition=None, init_field=None):
+ key = path + "/" + field
+
+ if getattr(obj, 'custom_script_source', None) is not None:
+ key = 'customscript/' + obj.custom_script_source + '/' + key
+
+ if getattr(obj, 'do_not_save_to_config', False):
+ return
+
+ saved_value = ui_settings.get(key, None)
+ if saved_value is None:
+ ui_settings[key] = getattr(obj, field)
+ elif condition and not condition(saved_value):
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ else:
+ setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
+
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
+ apply_field(x, 'visible')
+
+ if type(x) == gr.Slider:
+ apply_field(x, 'value')
+ apply_field(x, 'minimum')
+ apply_field(x, 'maximum')
+ apply_field(x, 'step')
+
+ if type(x) == gr.Radio:
+ apply_field(x, 'value', lambda val: val in x.choices)
+
+ if type(x) == gr.Checkbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Textbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Number:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Dropdown:
+ def check_dropdown(val):
+ if getattr(x, 'multiselect', False):
+ return all([value in x.choices for value in val])
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+ visit(txt2img_interface, loadsave, "txt2img")
+ visit(img2img_interface, loadsave, "img2img")
+ visit(extras_interface, loadsave, "extras")
+ visit(modelmerger_interface, loadsave, "modelmerger")
+ visit(train_interface, loadsave, "train")
+
+ if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
+ with open(ui_config_file, "w", encoding="utf8") as file:
+ json.dump(ui_settings, file, indent=4)
+
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
+ return demo
+
+
+def reload_javascript():
+ head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
+
+ inline = f"{localization.localization_js(shared.opts.localization)};"
+ if cmd_opts.theme is not None:
+ inline += f"set_theme('{cmd_opts.theme}');"
+
+ for script in modules.scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
+
+ head += f'<script type="text/javascript">{inline}</script>\n'
+
+ def template_response(*args, **kwargs):
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
+ res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
+ res.init_headers()
+ return res
+
+ gradio.routes.templates.TemplateResponse = template_response
+
+
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
+
+
+def versions_html():
+ import torch
+ import launch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = launch.commit_hash()
+ short_commit = commit[0:8]
+
+ if shared.xformers_available:
+ import xformers
+ xformers_version = xformers.__version__
+ else:
+ xformers_version = "N/A"
+
+ return f"""
+python: "),
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.hypernetwork_dir]
+
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..47491b2f09ccc960e2e237097f9d9e78075d25c0
--- /dev/null
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -0,0 +1,34 @@
+import json
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Textual Inversion')
+ self.allow_negative_prompt = True
+
+ def refresh(self):
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ def list_items(self):
+ for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ path, ext = os.path.splitext(embedding.filename)
+ preview_file = path + ".preview.png"
+
+ preview = None
+ if os.path.isfile(preview_file):
+ preview = self.link_preview(preview_file)
+
+ yield {
+ "name": embedding.name,
+ "filename": embedding.filename,
+ "preview": preview,
+ "search_term": self.search_terms_from_path(embedding.filename),
+ "prompt": json.dumps(embedding.name),
+ "local_preview": path + ".preview.png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..7789347028ecb309607038d0bc79eff934f45711
--- /dev/null
+++ b/modules/ui_postprocessing.py
@@ -0,0 +1,57 @@
+import gradio as gr
+from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+
+
+def create_ui():
+ tab_index = gr.State(value=0)
+
+ with gr.Row().style(equal_height=False, variant='compact'):
+ with gr.Column(variant='compact'):
+ with gr.Tabs(elem_id="mode_extras"):
+ with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
+ script_inputs = scripts.scripts_postproc.setup_ui()
+
+ with gr.Column():
+ result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+
+ submit.click(
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+ inputs=[
+ tab_index,
+ extras_image,
+ image_batch,
+ extras_batch_input_dir,
+ extras_batch_output_dir,
+ show_extras_results,
+ *script_inputs
+ ],
+ outputs=[
+ result_images,
+ html_info_x,
+ html_info,
+ ]
+ )
+
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+ extras_image.change(
+ fn=scripts.scripts_postproc.image_changed,
+ inputs=[], outputs=[]
+ )
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
new file mode 100644
index 0000000000000000000000000000000000000000..126f73a21d71070887fd094beaf0fe6d7e12df9c
--- /dev/null
+++ b/modules/ui_tempdir.py
@@ -0,0 +1,82 @@
+import os
+import tempfile
+from collections import namedtuple
+from pathlib import Path
+
+import gradio as gr
+
+from PIL import PngImagePlugin
+
+from modules import shared
+
+
+Savedfile = namedtuple("Savedfile", ["name"])
+
+
+def register_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
+
+
+def check_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'):
+ return any([filename in fileset for fileset in gradio.temp_file_sets])
+
+ if hasattr(gradio, 'temp_dirs'):
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
+
+ return False
+
+
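+# if the webui already wrote this image to disk, register that file with gradio and reuse it
+# instead of saving another temporary copy; otherwise write a temp PNG, carrying over text metadata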
+def save_pil_to_file(pil_image, dir=None):
+ already_saved_as = getattr(pil_image, 'already_saved_as', None)
+ if already_saved_as and os.path.isfile(already_saved_as):
+ register_tmp_file(shared.demo, already_saved_as)
+
+ file_obj = Savedfile(already_saved_as)
+ return file_obj
+
+ if shared.opts.temp_dir != "":
+ dir = shared.opts.temp_dir
+
+ use_metadata = False
+ metadata = PngImagePlugin.PngInfo()
+ for key, value in pil_image.info.items():
+ if isinstance(key, str) and isinstance(value, str):
+ metadata.add_text(key, value)
+ use_metadata = True
+
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
+ return file_obj
+
+
+# override save to file function so that it also writes PNG info
+gr.processing_utils.save_pil_to_file = save_pil_to_file
+
+
+def on_tmpdir_changed():
+ if shared.opts.temp_dir == "" or shared.demo is None:
+ return
+
+ os.makedirs(shared.opts.temp_dir, exist_ok=True)
+
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
+
+
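+# delete any .png files left over in the configured temp directory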
+def cleanup_tmpdr():
+ temp_dir = shared.opts.temp_dir
+ if temp_dir == "" or not os.path.isdir(temp_dir):
+ return
+
+ for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for name in files:
+ _, extension = os.path.splitext(name)
+ if extension != ".png":
+ continue
+
+ filename = os.path.join(root, name)
+ os.remove(filename)
diff --git a/modules/upscaler.py b/modules/upscaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2eaa7308af0091b6e8f407e889b2e446679e149
--- /dev/null
+++ b/modules/upscaler.py
@@ -0,0 +1,145 @@
+import os
+from abc import abstractmethod
+
+import PIL
+import numpy as np
+import torch
+from PIL import Image
+
+import modules.shared
+from modules import modelloader, shared
+
+LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
+
+
+class Upscaler:
+ name = None
+ model_path = None
+ model_name = None
+ model_url = None
+ enable = True
+ filter = None
+ model = None
+ user_path = None
+ scalers: []
+ tile = True
+
+ def __init__(self, create_dirs=False):
+ self.mod_pad_h = None
+ self.tile_size = modules.shared.opts.ESRGAN_tile
+ self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
+ self.device = modules.shared.device
+ self.img = None
+ self.output = None
+ self.scale = 1
+ self.half = not modules.shared.cmd_opts.no_half
+ self.pre_pad = 0
+ self.mod_scale = None
+
+ if self.model_path is None and self.name:
+ self.model_path = os.path.join(shared.models_path, self.name)
+ if self.model_path and create_dirs:
+ os.makedirs(self.model_path, exist_ok=True)
+
+ try:
+ import cv2
+ self.can_tile = True
+ except:
+ pass
+
+ @abstractmethod
+ def do_upscale(self, img: PIL.Image, selected_model: str):
+ return img
+
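+ # runs do_upscale repeatedly (up to three passes) until the requested size is reached or the
+ # model stops changing the image, then resizes to the exact destination size with Lanczos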
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
+ self.scale = scale
+ dest_w = int(img.width * scale)
+ dest_h = int(img.height * scale)
+
+ for i in range(3):
+ shape = (img.width, img.height)
+
+ img = self.do_upscale(img, selected_model)
+
+ if shape == (img.width, img.height):
+ break
+
+ if img.width >= dest_w and img.height >= dest_h:
+ break
+
+ if img.width != dest_w or img.height != dest_h:
+ img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
+
+ return img
+
+ @abstractmethod
+ def load_model(self, path: str):
+ pass
+
+ def find_models(self, ext_filter=None) -> list:
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
+
+ def update_status(self, prompt):
+ print(f"\nextras: {prompt}", file=shared.progress_print_out)
+
+
+class UpscalerData:
+ name = None
+ data_path = None
+ scale: int = 4
+ scaler: Upscaler = None
+ model: None
+
+ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
+ self.name = name
+ self.data_path = path
+ self.local_data_path = path
+ self.scaler = upscaler
+ self.scale = scale
+ self.model = model
+
+
+class UpscalerNone(Upscaler):
+ name = "None"
+ scalers = []
+
+ def load_model(self, path):
+ pass
+
+ def do_upscale(self, img, selected_model=None):
+ return img
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.scalers = [UpscalerData("None", None, self)]
+
+
+class UpscalerLanczos(Upscaler):
+ scalers = []
+
+ def do_upscale(self, img, selected_model=None):
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
+
+ def load_model(self, _):
+ pass
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.name = "Lanczos"
+ self.scalers = [UpscalerData("Lanczos", None, self)]
+
+
+class UpscalerNearest(Upscaler):
+ scalers = []
+
+ def do_upscale(self, img, selected_model=None):
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
+
+ def load_model(self, _):
+ pass
+
+ def __init__(self, dirname=None):
+ super().__init__(False)
+ self.name = "Nearest"
+ self.scalers = [UpscalerData("Nearest", None, self)]
diff --git a/modules/xlmr.py b/modules/xlmr.py
new file mode 100644
index 0000000000000000000000000000000000000000..beab3fdf55e7bcffd96f3b36679e7a90c0f390dc
--- /dev/null
+++ b/modules/xlmr.py
@@ -0,0 +1,137 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob= 0.1
+ config.bos_token_id=0
+ config.eos_token_id=2
+ config.hidden_act='gelu'
+ config.hidden_dropout_prob=0.1
+ config.hidden_size=1024
+ config.initializer_range=0.02
+ config.intermediate_size=4096
+ config.layer_norm_eps=1e-05
+ config.max_position_embeddings=514
+
+ config.num_attention_heads=16
+ config.num_hidden_layers=24
+ config.output_past=True
+ config.pad_token_id=1
+ config.position_embedding_type= "absolute"
+
+ config.type_vocab_size= 1
+ config.use_cache=True
+ config.vocab_size= 250002
+ config.project_dim = 768
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+ self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ self.pooler = lambda x: x[:,0]
+ self.post_init()
+
+ def encode(self,c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+ text["attention_mask"] = torch.tensor(
+ text['attention_mask']).to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) :
+ r"""
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ # last module outputs
+ sequence_output = outputs[0]
+
+
+ # project every module
+ sequence_output_ln = self.pre_LN(sequence_output)
+
+ # pooler
+ pooler_output = self.pooler(sequence_output_ln)
+ pooler_output = self.transformation(pooler_output)
+ projection_state = self.transformation(outputs.last_hidden_state)
+
+ return {
+ 'pooler_output':pooler_output,
+ 'last_hidden_state':outputs.last_hidden_state,
+ 'hidden_states':outputs.hidden_states,
+ 'attentions':outputs.attentions,
+ 'projection_state':projection_state,
+ 'sequence_out': sequence_output
+ }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
+ config_class= RobertaSeriesConfig
\ No newline at end of file
diff --git a/play.py b/play.py
new file mode 100644
index 0000000000000000000000000000000000000000..299634c682edde551685b6a2221279450837e793
--- /dev/null
+++ b/play.py
@@ -0,0 +1,20 @@
+import requests
+import random
+import time
+import base64
+import hashlib
+import json
+
+def lightning():
+ start = int(time.time())
+ url = "https://wsoqr-01gwy9mc1gzh3b4ce9b708vp31.litng-ai-03.litng.ai/predict"
+ form = {
+ "prompt": "extremely detailed CG unity 8k wallpaper, masterpiece, best quality, ultra-detailed, best illustration, best shadow, photorealistic:1.4, 1 gorgeous girls,oversize pink_hoodie,under eiffel tower,grey_hair:1.1, collarbone,puffy breasts:1.5,full body shot,shiny eyes,enjoyable expression,evil smile,slim legs,narrow waist,detailed face, looking at viewer,looking back,gorgeous skin,short curly hair,kneeling,puffy ass up,climbing,lying,rosy pussy,nsfw,insert left_hand into pussy",
+ }
+ resp = requests.post(url, json=form)
+ resp_data = json.loads(resp.content)
+ print(resp.status_code, '\n', resp_data)
+ print("time cost(ms): ", int(time.time())*1e3-start*1e3)
+
+
+lightning()
\ No newline at end of file
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d60585a63ad8c0cf6416e98c23cc4303b97692d
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,155 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+from cog import BasePredictor, Input, Path
+
+import os
+import sys
+import signal
+import time
+import re
+from typing import Dict, List, Any
+
+import logging
+logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
+from modules import errors
+from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+
+import torch
+
+# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+if ".dev" in torch.__version__ or "+git" in torch.__version__:
+ torch.__long_version__ = torch.__version__
+ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
+
+from modules import shared, devices, ui_tempdir
+from modules.api.api import encode_pil_to_base64
+import modules.codeformer_model as codeformer
+import modules.face_restoration
+import modules.gfpgan_model as gfpgan
+import modules.img2img
+
+import modules.lowvram
+import modules.paths
+import modules.scripts
+import modules.sd_hijack
+import modules.sd_models
+import modules.sd_vae
+import modules.txt2img
+import modules.script_callbacks
+import modules.textual_inversion.textual_inversion
+import modules.progress
+
+import modules.ui
+from modules import modelloader
+from modules.shared import cmd_opts, opts
+import modules.hypernetworks.hypernetwork
+
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+
+
+def initialize():
+ modelloader.cleanup_models()
+ modules.sd_models.setup_model()
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+
+ modelloader.list_builtin_upscalers()
+ modelloader.load_upscalers()
+
+ modules.sd_vae.refresh_vae_list()
+
+
+ try:
+ modules.sd_models.load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+ exit(1)
+
+ shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+
+ # make the program just exit at ctrl+c without waiting for anything
+ # def sigint_handler(sig, frame):
+ # print(f'Interrupted with signal {sig} in {frame}')
+ # os._exit(0)
+
+ # signal.signal(signal.SIGINT, sigint_handler)
+
+class Predictor(BasePredictor):
+ def setup(self):
+ """Load the model into memory to make running multiple predictions efficient"""
+ initialize()
+
+ def predict(
+ self,
+ prompt: str = Input(description="prompt en", default="<lora:koreanDollLikeness_v15:0.66>, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer"),
+ negative_prompt: str = Input(description="negative prompt", default="paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed"),
+ sampler_name: str = Input(description="sampler name", default="DPM++ SDE Karras", choices=["DPM++ SDE Karras", "DPM++ 2M Karras", "DPM++ 2S a Karras", "DPM2 a Karras", "DPM2 Karras", "LMS Karras", "DPM adaptive", "DPM fast", "DPM++ SDE", "DPM++ 2M", "DPM++ 2S a", "DPM2 a", "DPM2", "Heun", "LMS", "Euler", "Euler a"]),
+ steps: int = Input(description="steps", default=20),
+ cfg_scale: int = Input(description="cfg scale", default=8),
+ width: int = Input(description="width", default=512),
+ height: int = Input(description="height", default=768),
+ seed: int = Input(description="seed", default=-1),
+ ) -> Path:
+ """Run a single prediction on the model"""
+ args = {
+ "do_not_save_samples": True,
+ "do_not_save_grid": True,
+ "outpath_samples": "./output",
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "sampler_name": sampler_name,
+ "steps": steps, # 25
+ "cfg_scale": cfg_scale,
+ "width": width,
+ "height": height,
+ "seed": seed,
+ }
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
+ processed = process_images(p)
+ filename = str(int(time.time())) + ".png"
+ processed.images[0].save(fp=filename, format="PNG")
+ # single_image_b64 = encode_pil_to_base64(processed.images[0]).decode('utf-8')
+ return Path(filename)
+
+class PredictorOld(BasePredictor):
+ def setup(self):
+ """Load the model into memory to make running multiple predictions efficient"""
+ initialize()
+ self.shared = shared
+
+ def predict(
+ self,
+ prompt: str = Input(description="prompt en"),
+ ) -> Dict[str, Any]:
+ """Run a single prediction on the model"""
+ args = {
+ "do_not_save_samples": True,
+ "do_not_save_grid": True,
+ "outpath_samples": "./output",
+ "prompt": "lora:koreanDollLikeness_v15:0.66, best quality, ultra high res, (photorealistic:1.4), 1girl, beige sweater, black choker, smile, laughing, bare shoulders, solo focus, ((full body), (brown hair:1), looking at viewer",
+ "negative_prompt": "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans, (ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), (tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, 3hands,4fingers,3arms, bad anatomy, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts,poorly drawn face,mutation,deformed",
+ "sampler_name": "DPM++ SDE Karras",
+ "steps": 20, # 25
+ "cfg_scale": 8,
+ "width": 512,
+ "height": 768,
+ "seed": -1,
+ }
+ if len(prompt) > 0:
+ print("get prompt from request: ", prompt)
+ args["prompt"] = prompt
+ p = StableDiffusionProcessingTxt2Img(sd_model=self.shared.sd_model, **args)
+ processed = process_images(p)
+ single_image_b64 = encode_pil_to_base64(processed.images[0]).decode('utf-8')
+ return {
+ "img_data": single_image_b64,
+ "parameters": processed.images[0].info.get('parameters', ""),
+ }
diff --git a/repositories/BLIP/CODEOWNERS b/repositories/BLIP/CODEOWNERS
new file mode 100644
index 0000000000000000000000000000000000000000..522fa4a0f715cd0328b9b9dbacae00e060193f43
--- /dev/null
+++ b/repositories/BLIP/CODEOWNERS
@@ -0,0 +1,2 @@
+# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
+#ECCN:Open Source
diff --git a/repositories/BLIP/CODE_OF_CONDUCT.md b/repositories/BLIP/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..b6724718c9512d730bb7f1bcc5848cd420241407
--- /dev/null
+++ b/repositories/BLIP/CODE_OF_CONDUCT.md
@@ -0,0 +1,105 @@
+# Salesforce Open Source Community Code of Conduct
+
+## About the Code of Conduct
+
+Equality is a core value at Salesforce. We believe a diverse and inclusive
+community fosters innovation and creativity, and are committed to building a
+culture where everyone feels included.
+
+Salesforce open-source projects are committed to providing a friendly, safe, and
+welcoming environment for all, regardless of gender identity and expression,
+sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
+race, age, religion, level of experience, education, socioeconomic status, or
+other similar personal characteristics.
+
+The goal of this code of conduct is to specify a baseline standard of behavior so
+that people with different social values and communication styles can work
+together effectively, productively, and respectfully in our open source community.
+It also establishes a mechanism for reporting issues and resolving conflicts.
+
+All questions and reports of abusive, harassing, or otherwise unacceptable behavior
+in a Salesforce open-source project may be reported by contacting the Salesforce
+Open Source Conduct Committee at ossconduct@salesforce.com.
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of gender
+identity and expression, sexual orientation, disability, physical appearance,
+body size, ethnicity, nationality, race, age, religion, level of experience, education,
+socioeconomic status, or other similar personal characteristics.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy toward other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Personal attacks, insulting/derogatory comments, or trolling
+* Public or private harassment
+* Publishing, or threatening to publish, others' private information—such as
+a physical or electronic address—without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
+* Advocating for or encouraging any of the above behaviors
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned with this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project email
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the Salesforce Open Source Conduct Committee
+at ossconduct@salesforce.com. All complaints will be reviewed and investigated
+and will result in a response that is deemed necessary and appropriate to the
+circumstances. The committee is obligated to maintain confidentiality with
+regard to the reporter of an incident. Further details of specific enforcement
+policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership and the Salesforce Open Source Conduct
+Committee.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
+version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
+It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
+[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
+
+This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
+
+[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
+[golang-coc]: https://golang.org/conduct
+[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
+[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
+[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
diff --git a/repositories/BLIP/LICENSE.txt b/repositories/BLIP/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a63e87f4e1e90c96861648a16a7304d97d3c3f7b
--- /dev/null
+++ b/repositories/BLIP/LICENSE.txt
@@ -0,0 +1,12 @@
+Copyright (c) 2022, Salesforce.com, Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/repositories/BLIP/README.md b/repositories/BLIP/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a86ebfc21536b97fde1789f0c9ec4a53d2bdd77
--- /dev/null
+++ b/repositories/BLIP/README.md
@@ -0,0 +1,114 @@
+## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
+
+
+
+This is the PyTorch code of the BLIP paper [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
+To install the dependencies, run `pip install -r requirements.txt`.
+
+Catalog:
+- [x] Inference demo
+- [x] Pre-trained and finetuned checkpoints
+- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
+- [x] Pre-training code
+- [x] Zero-shot video-text retrieval
+- [x] Download of bootstrapped pre-training datasets
+
+
+### Inference demo:
+Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
+The demo includes code for:
+1. Image captioning
+2. Open-ended visual question answering
+3. Multimodal / unimodal feature extraction
+4. Image-text matching
+
+Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
+
+A Replicate web demo and Docker image are also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
+
+### Pre-trained checkpoints:
+Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+--- | :---: | :---: | :---:
+14M | Download| - | -
+129M | Download| Download | Download
+
+### Finetuned checkpoints:
+Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
+--- | :---: | :---: | :---:
+Image-Text Retrieval (COCO) | Download| - | Download
+Image-Text Retrieval (Flickr30k) | Download| - | Download
+Image Captioning (COCO) | - | Download| Download |
+VQA | Download| Download | -
+NLVR2 | Download| - | -
+
+
+### Image-Text Retrieval:
+1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
+2. To evaluate the finetuned BLIP model on COCO, run:
+python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+--config ./configs/retrieval_coco.yaml \
+--output_dir output/retrieval_coco \
+--evaluate
+3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
+--config ./configs/retrieval_coco.yaml \
+--output_dir output/retrieval_coco
+
+### Image-Text Captioning:
+1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
+2. To evaluate the finetuned BLIP model on COCO, run:
+python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate
+3. To evaluate the finetuned BLIP model on NoCaps, generate results with: (evaluation needs to be performed on official server)
+python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py
+4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=8 train_caption.py
+
+### VQA:
+1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
+2. To evaluate the finetuned BLIP model, generate results with: (evaluation needs to be performed on official server)
+python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=16 train_vqa.py
+
+### NLVR2:
+1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
+2. To evaluate the finetuned BLIP model, run
+python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate
+3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
+python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py
+
+### Finetune with ViT-L:
+In order to finetune a model with ViT-L, simply change the config file to set 'vit' to 'large'. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). Gradient checkpointing can also be activated in the config file to reduce GPU memory usage.
+
+### Pre-train:
+1. Prepare training json files where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}. A minimal sketch of producing such a file is shown below this list.
+2. In configs/pretrain.yaml, set 'train_file' to the paths of the json files.
+3. Pre-train the model using 8 A100 GPUs:
+python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain
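+
+Each training json file can be produced from any collection of image/caption pairs. The snippet below is a minimal sketch (the file name my_pretrain.json and the image paths are placeholders, not files shipped with the repo):
+```python
+import json
+
+# each record pairs a local image path with its caption, matching the
+# {'image': path_of_image, 'caption': text_of_image} format described above
+annotations = [
+    {"image": "/data/images/000001.jpg", "caption": "a dog running on the beach"},
+    {"image": "/data/images/000002.jpg", "caption": "a bowl of fruit on a wooden table"},
+]
+
+# write the list to a json file that can then be listed under 'train_file' in configs/pretrain.yaml
+with open("my_pretrain.json", "w") as f:
+    json.dump(annotations, f)
+```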
+
+### Zero-shot video-text retrieval:
+1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
+2. Install [decord](https://github.com/dmlc/decord) with `pip install decord`
+3. To perform zero-shot evaluation, run
+python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py
+
+### Pre-training datasets download:
+We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}. A short download sketch follows the table below.
+
+Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
+--- | :---: | :---: | :---:
+CC3M+CC12M+SBU | Download| Download| Download
+LAION115M | Download| Download| Download
+
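+As a rough illustration of how these files can be consumed (the file name ccs_filtered.json, the output directory, and the 100-record limit are arbitrary choices for this sketch), each record's image can be fetched from its 'url' and re-paired with its caption:
+```python
+import json
+import os
+import urllib.request
+
+# load one of the released json files: a list of {'url': ..., 'caption': ...} records
+with open("ccs_filtered.json", "r") as f:
+    records = json.load(f)
+
+os.makedirs("downloaded_images", exist_ok=True)
+pairs = []
+for i, rec in enumerate(records[:100]):  # small subset, for illustration only
+    path = os.path.join("downloaded_images", f"{i:08d}.jpg")
+    try:
+        urllib.request.urlretrieve(rec["url"], path)  # fetch the image
+        pairs.append({"image": path, "caption": rec["caption"]})
+    except Exception:
+        continue  # skip unreachable urls
+
+# the result uses the same {'image', 'caption'} format expected for pre-training
+with open("downloaded_subset.json", "w") as f:
+    json.dump(pairs, f)
+```
+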
+### Citation
+If you find this code to be useful for your research, please consider citing.
+
+@inproceedings{li2022blip,
+ title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
+ author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
+ year={2022},
+ booktitle={ICML},
+}
+
+### Acknowledgement
+The implementation of BLIP relies on resources from ALBEF, Huggingface Transformers, and timm. We thank the original authors for open-sourcing their work.
diff --git a/repositories/BLIP/SECURITY.md b/repositories/BLIP/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..8249025739809035264e7776583b2f3ec100553c
--- /dev/null
+++ b/repositories/BLIP/SECURITY.md
@@ -0,0 +1,7 @@
+## Security
+
+Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
+as soon as it is discovered. This library limits its runtime dependencies in
+order to reduce the total cost of ownership as much as possible, but all consumers
+should remain vigilant and have their security stakeholders review all third-party
+products (3PP) like this one and their dependencies.
diff --git a/repositories/BLIP/cog.yaml b/repositories/BLIP/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c1dfcc430a4cab0fdd2a60a682336219a61c4a4f
--- /dev/null
+++ b/repositories/BLIP/cog.yaml
@@ -0,0 +1,17 @@
+build:
+ gpu: true
+ cuda: "11.1"
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "ipython==7.30.1"
+ - "torchvision==0.11.1"
+ - "torch==1.10.0"
+ - "timm==0.4.12"
+ - "transformers==4.15.0"
+ - "fairscale==0.4.4"
+ - "pycocoevalcap==1.2"
+
+predict: "predict.py:Predictor"
diff --git a/repositories/BLIP/configs/bert_config.json b/repositories/BLIP/configs/bert_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3ef38aabc7f966b53079e9d559dc59e459cc0051
--- /dev/null
+++ b/repositories/BLIP/configs/bert_config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30522,
+ "encoder_width": 768,
+ "add_cross_attention": true
+}
diff --git a/repositories/BLIP/configs/caption_coco.yaml b/repositories/BLIP/configs/caption_coco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..42eab7030c0310ba2f265baf36fa1400aa6e5846
--- /dev/null
+++ b/repositories/BLIP/configs/caption_coco.yaml
@@ -0,0 +1,33 @@
+image_root: '/export/share/datasets/vision/coco/images/'
+ann_root: 'annotation'
+coco_gt_root: 'annotation/coco_gt'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+# size of vit model; base or large
+vit: 'base'
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+batch_size: 32
+init_lr: 1e-5
+
+# vit: 'large'
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 5
+# batch_size: 16
+# init_lr: 2e-6
+
+image_size: 384
+
+# generation configs
+max_length: 20
+min_length: 5
+num_beams: 3
+prompt: 'a picture of '
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 5
+
diff --git a/repositories/BLIP/configs/med_config.json b/repositories/BLIP/configs/med_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..0ffad0a6f3c2f9f11b8faa84529d9860bb70327a
--- /dev/null
+++ b/repositories/BLIP/configs/med_config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": [
+ "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30524,
+ "encoder_width": 768,
+ "add_cross_attention": true
+}
diff --git a/repositories/BLIP/configs/nlvr.yaml b/repositories/BLIP/configs/nlvr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d1122aadb1a776bd347068233096b0c984f648b
--- /dev/null
+++ b/repositories/BLIP/configs/nlvr.yaml
@@ -0,0 +1,21 @@
+image_root: '/export/share/datasets/vision/NLVR2/'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
+
+#size of vit model; base or large
+vit: 'base'
+batch_size_train: 16
+batch_size_test: 64
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+max_epoch: 15
+
+image_size: 384
+
+# optimizer
+weight_decay: 0.05
+init_lr: 3e-5
+min_lr: 0
+
diff --git a/repositories/BLIP/configs/nocaps.yaml b/repositories/BLIP/configs/nocaps.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9028135859b94aef5324c85c80e376c609d8a089
--- /dev/null
+++ b/repositories/BLIP/configs/nocaps.yaml
@@ -0,0 +1,15 @@
+image_root: '/export/share/datasets/vision/nocaps/'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
+
+vit: 'base'
+batch_size: 32
+
+image_size: 384
+
+max_length: 20
+min_length: 5
+num_beams: 3
+prompt: 'a picture of '
\ No newline at end of file
diff --git a/repositories/BLIP/configs/pretrain.yaml b/repositories/BLIP/configs/pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..02355ee0228932803c661616485bf315e862b826
--- /dev/null
+++ b/repositories/BLIP/configs/pretrain.yaml
@@ -0,0 +1,27 @@
+train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
+ '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
+ ]
+laion_path: ''
+
+# size of vit model; base or large
+vit: 'base'
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+
+image_size: 224
+batch_size: 75
+
+queue_size: 57600
+alpha: 0.4
+
+# optimizer
+weight_decay: 0.05
+init_lr: 3e-4
+min_lr: 1e-6
+warmup_lr: 1e-6
+lr_decay_rate: 0.9
+max_epoch: 20
+warmup_steps: 3000
+
+
+
diff --git a/repositories/BLIP/configs/retrieval_coco.yaml b/repositories/BLIP/configs/retrieval_coco.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a8569e9b67112fe3605ac25e4fdc0231f7975378
--- /dev/null
+++ b/repositories/BLIP/configs/retrieval_coco.yaml
@@ -0,0 +1,34 @@
+image_root: '/export/share/datasets/vision/coco/images/'
+ann_root: 'annotation'
+dataset: 'coco'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+# size of vit model; base or large
+
+vit: 'base'
+batch_size_train: 32
+batch_size_test: 64
+vit_grad_ckpt: True
+vit_ckpt_layer: 4
+init_lr: 1e-5
+
+# vit: 'large'
+# batch_size_train: 16
+# batch_size_test: 32
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 12
+# init_lr: 5e-6
+
+image_size: 384
+queue_size: 57600
+alpha: 0.4
+k_test: 256
+negative_all_rank: True
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 6
+
diff --git a/repositories/BLIP/configs/retrieval_flickr.yaml b/repositories/BLIP/configs/retrieval_flickr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d75ea4eed87c9a001523c5e5914998c5e737594d
--- /dev/null
+++ b/repositories/BLIP/configs/retrieval_flickr.yaml
@@ -0,0 +1,34 @@
+image_root: '/export/share/datasets/vision/flickr30k/'
+ann_root: 'annotation'
+dataset: 'flickr'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
+
+# size of vit model; base or large
+
+vit: 'base'
+batch_size_train: 32
+batch_size_test: 64
+vit_grad_ckpt: True
+vit_ckpt_layer: 4
+init_lr: 1e-5
+
+# vit: 'large'
+# batch_size_train: 16
+# batch_size_test: 32
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 10
+# init_lr: 5e-6
+
+image_size: 384
+queue_size: 57600
+alpha: 0.4
+k_test: 128
+negative_all_rank: False
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 6
+
diff --git a/repositories/BLIP/configs/retrieval_msrvtt.yaml b/repositories/BLIP/configs/retrieval_msrvtt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..395f62542bb22d706b8e19e2455d2c7298984d0b
--- /dev/null
+++ b/repositories/BLIP/configs/retrieval_msrvtt.yaml
@@ -0,0 +1,12 @@
+video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size: 64
+k_test: 128
+image_size: 384
+num_frm_test: 8
\ No newline at end of file
diff --git a/repositories/BLIP/configs/vqa.yaml b/repositories/BLIP/configs/vqa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..74327e6d0a34672023b44569558fe8beeb052548
--- /dev/null
+++ b/repositories/BLIP/configs/vqa.yaml
@@ -0,0 +1,25 @@
+vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
+vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
+train_files: ['vqa_train','vqa_val','vg_qa']
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size_train: 16
+batch_size_test: 32
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+init_lr: 2e-5
+
+image_size: 480
+
+k_test: 128
+inference: 'rank'
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 10
\ No newline at end of file
diff --git a/repositories/BLIP/data/__init__.py b/repositories/BLIP/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be209acf415855ea6ef753efedf903b5decb6b9
--- /dev/null
+++ b/repositories/BLIP/data/__init__.py
@@ -0,0 +1,101 @@
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
+from data.nocaps_dataset import nocaps_eval
+from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
+from data.vqa_dataset import vqa_dataset
+from data.nlvr_dataset import nlvr_dataset
+from data.pretrain_dataset import pretrain_dataset
+from transform.randaugment import RandomAugment
+
+def create_dataset(dataset, config, min_scale=0.5):
+
+ normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+ transform_train = transforms.Compose([
+ transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
+ transforms.RandomHorizontalFlip(),
+ RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+ transforms.ToTensor(),
+ normalize,
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ normalize,
+ ])
+
+ if dataset=='pretrain':
+ dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
+ return dataset
+
+ elif dataset=='caption_coco':
+ train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
+ val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+ test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+ return train_dataset, val_dataset, test_dataset
+
+ elif dataset=='nocaps':
+ val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+ test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+ return val_dataset, test_dataset
+
+ elif dataset=='retrieval_coco':
+ train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
+ val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+ test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+ return train_dataset, val_dataset, test_dataset
+
+ elif dataset=='retrieval_flickr':
+ train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
+ val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+ test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+ return train_dataset, val_dataset, test_dataset
+
+ elif dataset=='vqa':
+ train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
+ train_files = config['train_files'], split='train')
+ test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
+ return train_dataset, test_dataset
+
+ elif dataset=='nlvr':
+ train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
+ val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
+ test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
+ return train_dataset, val_dataset, test_dataset
+
+
+def create_sampler(datasets, shuffles, num_tasks, global_rank):
+ samplers = []
+ for dataset,shuffle in zip(datasets,shuffles):
+ sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
+ samplers.append(sampler)
+ return samplers
+
+
+def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
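+ # build one DataLoader per dataset: training loaders shuffle only when no DistributedSampler is given and drop the last incomplete batch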
+ loaders = []
+ for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
+ if is_train:
+ shuffle = (sampler is None)
+ drop_last = True
+ else:
+ shuffle = False
+ drop_last = False
+ loader = DataLoader(
+ dataset,
+ batch_size=bs,
+ num_workers=n_worker,
+ pin_memory=True,
+ sampler=sampler,
+ shuffle=shuffle,
+ collate_fn=collate_fn,
+ drop_last=drop_last,
+ )
+ loaders.append(loader)
+ return loaders
+
diff --git a/repositories/BLIP/data/coco_karpathy_dataset.py b/repositories/BLIP/data/coco_karpathy_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34d29205f42aa09695b160ac9c91958ba041bb3
--- /dev/null
+++ b/repositories/BLIP/data/coco_karpathy_dataset.py
@@ -0,0 +1,126 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class coco_karpathy_train(Dataset):
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+ '''
+ image_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ '''
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
+ filename = 'coco_karpathy_train.json'
+
+ download_url(url,ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+ self.transform = transform
+ self.image_root = image_root
+ self.max_words = max_words
+ self.prompt = prompt
+
+ self.img_ids = {}
+ n = 0
+ for ann in self.annotation:
+ img_id = ann['image_id']
+ if img_id not in self.img_ids.keys():
+ self.img_ids[img_id] = n
+ n += 1
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.image_root,ann['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+ return image, caption, self.img_ids[ann['image_id']]
+
+
+class coco_karpathy_caption_eval(Dataset):
+ def __init__(self, transform, image_root, ann_root, split):
+ '''
+ image_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ '''
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+ download_url(urls[split],ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+ self.transform = transform
+ self.image_root = image_root
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.image_root,ann['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
+
+ return image, int(img_id)
+
+
+class coco_karpathy_retrieval_eval(Dataset):
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
+ '''
+ image_root (string): Root directory of images (e.g. coco/images/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ '''
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+ download_url(urls[split],ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+ self.transform = transform
+ self.image_root = image_root
+
+ self.text = []
+ self.image = []
+ self.txt2img = {}
+ self.img2txt = {}
+
+ txt_id = 0
+ for img_id, ann in enumerate(self.annotation):
+ self.image.append(ann['image'])
+ self.img2txt[img_id] = []
+ for i, caption in enumerate(ann['caption']):
+ self.text.append(pre_caption(caption,max_words))
+ self.img2txt[img_id].append(txt_id)
+ self.txt2img[txt_id] = img_id
+ txt_id += 1
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ return image, index
\ No newline at end of file
diff --git a/repositories/BLIP/data/flickr30k_dataset.py b/repositories/BLIP/data/flickr30k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..018ab387014ddaf554c4d3184cfc0e2ba8b2d487
--- /dev/null
+++ b/repositories/BLIP/data/flickr30k_dataset.py
@@ -0,0 +1,93 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class flickr30k_train(Dataset):
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+ '''
+ image_root (string): Root directory of images (e.g. flickr30k/)
+ ann_root (string): directory to store the annotation file
+ '''
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
+ filename = 'flickr30k_train.json'
+
+ download_url(url,ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+ self.transform = transform
+ self.image_root = image_root
+ self.max_words = max_words
+ self.prompt = prompt
+
+ self.img_ids = {}
+ n = 0
+ for ann in self.annotation:
+ img_id = ann['image_id']
+ if img_id not in self.img_ids.keys():
+ self.img_ids[img_id] = n
+ n += 1
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.image_root,ann['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+ return image, caption, self.img_ids[ann['image_id']]
+
+
+class flickr30k_retrieval_eval(Dataset):
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
+ '''
+ image_root (string): Root directory of images (e.g. flickr30k/)
+ ann_root (string): directory to store the annotation file
+ split (string): val or test
+ '''
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
+
+ download_url(urls[split],ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+ self.transform = transform
+ self.image_root = image_root
+
+ self.text = []
+ self.image = []
+ self.txt2img = {}
+ self.img2txt = {}
+
+ txt_id = 0
+ for img_id, ann in enumerate(self.annotation):
+ self.image.append(ann['image'])
+ self.img2txt[img_id] = []
+ for i, caption in enumerate(ann['caption']):
+ self.text.append(pre_caption(caption,max_words))
+ self.img2txt[img_id].append(txt_id)
+ self.txt2img[txt_id] = img_id
+ txt_id += 1
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ return image, index
\ No newline at end of file
diff --git a/repositories/BLIP/data/nlvr_dataset.py b/repositories/BLIP/data/nlvr_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d6b2d7cd8d3260bd279c7dca80de53bacc691a
--- /dev/null
+++ b/repositories/BLIP/data/nlvr_dataset.py
@@ -0,0 +1,78 @@
+import os
+import json
+import random
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class nlvr_dataset(Dataset):
+ def __init__(self, transform, image_root, ann_root, split):
+ '''
+ image_root (string): Root directory of images
+ ann_root (string): directory to store the annotation file
+ split (string): train, val or test
+ '''
+ urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
+ 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
+ filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
+
+ download_url(urls[split],ann_root)
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+
+ self.transform = transform
+ self.image_root = image_root
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image0_path = os.path.join(self.image_root,ann['images'][0])
+ image0 = Image.open(image0_path).convert('RGB')
+ image0 = self.transform(image0)
+
+ image1_path = os.path.join(self.image_root,ann['images'][1])
+ image1 = Image.open(image1_path).convert('RGB')
+ image1 = self.transform(image1)
+
+ sentence = pre_caption(ann['sentence'], 40)
+
+ if ann['label']=='True':
+ label = 1
+ else:
+ label = 0
+
+ words = sentence.split(' ')
+
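+ # augmentation: randomly swap the two images; if the sentence mentions 'left' or 'right', also swap those words so the label stays consistent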
+ if 'left' not in words and 'right' not in words:
+ if random.random()<0.5:
+ return image0, image1, sentence, label
+ else:
+ return image1, image0, sentence, label
+ else:
+ if random.random()<0.5:
+ return image0, image1, sentence, label
+ else:
+ new_words = []
+ for word in words:
+ if word=='left':
+ new_words.append('right')
+ elif word=='right':
+ new_words.append('left')
+ else:
+ new_words.append(word)
+
+ sentence = ' '.join(new_words)
+ return image1, image0, sentence, label
+
+
+
\ No newline at end of file
diff --git a/repositories/BLIP/data/nocaps_dataset.py b/repositories/BLIP/data/nocaps_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0bed06d8af3dbaccf18a56e725f101e585503e
--- /dev/null
+++ b/repositories/BLIP/data/nocaps_dataset.py
@@ -0,0 +1,32 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+class nocaps_eval(Dataset):
+ def __init__(self, transform, image_root, ann_root, split):
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
+ filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
+
+ download_url(urls[split],ann_root)
+
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+ self.transform = transform
+ self.image_root = image_root
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image_path = os.path.join(self.image_root,ann['image'])
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ return image, int(ann['img_id'])
\ No newline at end of file
diff --git a/repositories/BLIP/data/pretrain_dataset.py b/repositories/BLIP/data/pretrain_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..703d543ab5267fdc6fe2b7c84ef6a631d8af90ad
--- /dev/null
+++ b/repositories/BLIP/data/pretrain_dataset.py
@@ -0,0 +1,59 @@
+import json
+import os
+import random
+
+from torch.utils.data import Dataset
+
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.MAX_IMAGE_PIXELS = None
+
+from data.utils import pre_caption
+import os,glob
+
+class pretrain_dataset(Dataset):
+ def __init__(self, ann_file, laion_path, transform):
+
+ self.ann_pretrain = []
+ for f in ann_file:
+ print('loading '+f)
+ ann = json.load(open(f,'r'))
+ self.ann_pretrain += ann
+
+ self.laion_path = laion_path
+ if self.laion_path:
+ self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
+
+ print('loading '+self.laion_files[0])
+ with open(self.laion_files[0],'r') as f:
+ self.ann_laion = json.load(f)
+
+ self.annotation = self.ann_pretrain + self.ann_laion
+ else:
+ self.annotation = self.ann_pretrain
+
+ self.transform = transform
+
+
+ def reload_laion(self, epoch):
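+ # cycle through the LAION shard files, loading a different shard each epoch and rebuilding the combined annotation list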
+ n = epoch%len(self.laion_files)
+ print('loading '+self.laion_files[n])
+ with open(self.laion_files[n],'r') as f:
+ self.ann_laion = json.load(f)
+
+ self.annotation = self.ann_pretrain + self.ann_laion
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ image = Image.open(ann['image']).convert('RGB')
+ image = self.transform(image)
+ caption = pre_caption(ann['caption'],30)
+
+ return image, caption
\ No newline at end of file
diff --git a/repositories/BLIP/data/utils.py b/repositories/BLIP/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..628894844becd462d444584b8b2b01a84ee4b8f7
--- /dev/null
+++ b/repositories/BLIP/data/utils.py
@@ -0,0 +1,112 @@
+import re
+import json
+import os
+
+import torch
+import torch.distributed as dist
+
+import utils
+
+def pre_caption(caption,max_words=50):
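+ # lower-case the caption, replace common punctuation with spaces, collapse repeated whitespace, and truncate to max_words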
+ caption = re.sub(
+ r"([.!\"()*#:;~])",
+ ' ',
+ caption.lower(),
+ )
+ caption = re.sub(
+ r"\s{2,}",
+ ' ',
+ caption,
+ )
+ caption = caption.rstrip('\n')
+ caption = caption.strip(' ')
+
+ #truncate caption
+ caption_words = caption.split(' ')
+ if len(caption_words)>max_words:
+ caption = ' '.join(caption_words[:max_words])
+
+ return caption
+
+def pre_question(question,max_ques_words=50):
+ question = re.sub(
+ r"([.!\"()*#:;~])",
+ '',
+ question.lower(),
+ )
+ question = question.rstrip(' ')
+
+ #truncate question
+ question_words = question.split(' ')
+ if len(question_words)>max_ques_words:
+ question = ' '.join(question_words[:max_ques_words])
+
+ return question
+
+
+def save_result(result, result_dir, filename, remove_duplicate=''):
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
+ final_result_file = os.path.join(result_dir, '%s.json'%filename)
+
+ json.dump(result,open(result_file,'w'))
+
+ dist.barrier()
+
+ if utils.is_main_process():
+ # combine results from all processes
+ result = []
+
+ for rank in range(utils.get_world_size()):
+ result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
+ res = json.load(open(result_file,'r'))
+ result += res
+
+ if remove_duplicate:
+ result_new = []
+ id_list = []
+ for res in result:
+ if res[remove_duplicate] not in id_list:
+ id_list.append(res[remove_duplicate])
+ result_new.append(res)
+ result = result_new
+
+ json.dump(result,open(final_result_file,'w'))
+ print('result file saved to %s'%final_result_file)
+
+ return final_result_file
+
+
+
+from pycocotools.coco import COCO
+from pycocoevalcap.eval import COCOEvalCap
+from torchvision.datasets.utils import download_url
+
+def coco_caption_eval(coco_gt_root, results_file, split):
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
+ filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
+
+ download_url(urls[split],coco_gt_root)
+ annotation_file = os.path.join(coco_gt_root,filenames[split])
+
+ # create coco object and coco_result object
+ coco = COCO(annotation_file)
+ coco_result = coco.loadRes(results_file)
+
+ # create coco_eval object by taking coco and coco_result
+ coco_eval = COCOEvalCap(coco, coco_result)
+
+ # evaluate on a subset of images by setting
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
+ # please remove this line when evaluating the full validation set
+ # coco_eval.params['image_id'] = coco_result.getImgIds()
+
+ # evaluate results
+ # SPICE will take a few minutes the first time, but speeds up due to caching
+ coco_eval.evaluate()
+
+ # print output evaluation scores
+ for metric, score in coco_eval.eval.items():
+ print(f'{metric}: {score:.3f}')
+
+ return coco_eval
\ No newline at end of file
diff --git a/repositories/BLIP/data/video_dataset.py b/repositories/BLIP/data/video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6f8a61105bbd4285f98b3abe9445b73fd4c7ef
--- /dev/null
+++ b/repositories/BLIP/data/video_dataset.py
@@ -0,0 +1,110 @@
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+import torch
+import numpy as np
+import random
+import decord
+from decord import VideoReader
+import json
+import os
+from data.utils import pre_caption
+
+decord.bridge.set_bridge("torch")
+
+class ImageNorm(object):
+ """Apply Normalization to Image Pixels on GPU
+ """
+ def __init__(self, mean, std):
+ self.mean = torch.tensor(mean).view(1, 3, 1, 1)
+ self.std = torch.tensor(std).view(1, 3, 1, 1)
+
+ def __call__(self, img):
+
+ if torch.max(img) > 1 and self.mean.max() <= 1:
+ img.div_(255.)
+ return img.sub_(self.mean).div_(self.std)
+
+def load_jsonl(filename):
+ with open(filename, "r") as f:
+ return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+class VideoDataset(Dataset):
+
+ def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
+ '''
+ image_root (string): Root directory of video
+ ann_root (string): directory to store the annotation file
+ '''
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
+ filename = 'msrvtt_test.jsonl'
+
+ download_url(url,ann_root)
+ self.annotation = load_jsonl(os.path.join(ann_root,filename))
+
+ self.num_frm = num_frm
+ self.frm_sampling_strategy = frm_sampling_strategy
+ self.max_img_size = max_img_size
+ self.video_root = video_root
+ self.video_fmt = video_fmt
+ self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+
+ self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
+ self.txt2video = [i for i in range(len(self.annotation))]
+ self.video2txt = self.txt2video
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
+
+ vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
+
+ video = self.img_norm(vid_frm_array.float())
+
+ return video, ann['clip_name']
+
+
+
+ def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
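+ # decode the video with decord and sample self.num_frm frames: 'uniform' spaces them evenly, 'rand' draws a sorted random sample, 'headtail' takes half from each half of the clip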
+ try:
+ if not height or not width:
+ vr = VideoReader(video_path)
+ else:
+ vr = VideoReader(video_path, width=width, height=height)
+
+ vlen = len(vr)
+
+ if start_time or end_time:
+ assert fps > 0, 'must provide video fps if specifying start and end time.'
+
+ start_idx = min(int(start_time * fps), vlen)
+ end_idx = min(int(end_time * fps), vlen)
+ else:
+ start_idx, end_idx = 0, vlen
+
+ if self.frm_sampling_strategy == 'uniform':
+ frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
+ elif self.frm_sampling_strategy == 'rand':
+ frame_indices = sorted(random.sample(range(vlen), self.num_frm))
+ elif self.frm_sampling_strategy == 'headtail':
+ frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
+ frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
+ frame_indices = frame_indices_head + frame_indices_tail
+ else:
+ raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
+
+ raw_sample_frms = vr.get_batch(frame_indices)
+ except Exception as e:
+ return None
+
+ raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
+
+ return raw_sample_frms
diff --git a/repositories/BLIP/data/vqa_dataset.py b/repositories/BLIP/data/vqa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ec1df429b3910316ddd554bfea01c6e7922cae
--- /dev/null
+++ b/repositories/BLIP/data/vqa_dataset.py
@@ -0,0 +1,88 @@
+import os
+import json
+import random
+from PIL import Image
+
+import torch
+from torch.utils.data import Dataset
+from data.utils import pre_question
+
+from torchvision.datasets.utils import download_url
+
+class vqa_dataset(Dataset):
+ def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
+ self.split = split
+
+ self.transform = transform
+ self.vqa_root = vqa_root
+ self.vg_root = vg_root
+
+ if split=='train':
+ urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
+ 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
+ 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
+
+ self.annotation = []
+ for f in train_files:
+ download_url(urls[f],ann_root)
+ self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
+ else:
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
+ self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
+
+ download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
+ self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
+
+
+ def __len__(self):
+ return len(self.annotation)
+
+ def __getitem__(self, index):
+
+ ann = self.annotation[index]
+
+ if ann['dataset']=='vqa':
+ image_path = os.path.join(self.vqa_root,ann['image'])
+ elif ann['dataset']=='vg':
+ image_path = os.path.join(self.vg_root,ann['image'])
+
+ image = Image.open(image_path).convert('RGB')
+ image = self.transform(image)
+
+ if self.split == 'test':
+ question = pre_question(ann['question'])
+ question_id = ann['question_id']
+ return image, question, question_id
+
+
+ elif self.split=='train':
+
+ question = pre_question(ann['question'])
+
+ if ann['dataset']=='vqa':
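+ # VQA: weight each unique answer by its frequency among the annotated answers for this question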
+ answer_weight = {}
+ for answer in ann['answer']:
+ if answer in answer_weight.keys():
+ answer_weight[answer] += 1/len(ann['answer'])
+ else:
+ answer_weight[answer] = 1/len(ann['answer'])
+
+ answers = list(answer_weight.keys())
+ weights = list(answer_weight.values())
+
+ elif ann['dataset']=='vg':
+ answers = [ann['answer']]
+ weights = [0.2]
+
+ return image, question, answers, weights
+
+
+def vqa_collate_fn(batch):
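+ # flatten the per-question answer lists into one list; n records how many answers belong to each question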
+ image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
+ for image, question, answer, weights in batch:
+ image_list.append(image)
+ question_list.append(question)
+ weight_list += weights
+ answer_list += answer
+ n.append(len(answer))
+ return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
\ No newline at end of file
diff --git a/repositories/BLIP/demo.ipynb b/repositories/BLIP/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..40207f52ad49f2f3a00c798cf405b215e2e90fba
--- /dev/null
+++ b/repositories/BLIP/demo.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c86d1f76a227c4f6ea711e8eb8743b23cfac100058a3637e795d5e6dccfeeae7
+size 687381
diff --git a/repositories/BLIP/eval_nocaps.py b/repositories/BLIP/eval_nocaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbb09a8cc7771605c013583d721aa95d9413b42
--- /dev/null
+++ b/repositories/BLIP/eval_nocaps.py
@@ -0,0 +1,118 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip import blip_decoder
+import utils
+from data import create_dataset, create_sampler, create_loader
+from data.utils import save_result
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+ # evaluate
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Evaluation:'
+ print_freq = 10
+
+ result = []
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+ image = image.to(device)
+
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+ min_length=config['min_length'], repetition_penalty=1.1)
+
+ for caption, img_id in zip(captions, image_id):
+ result.append({"image_id": img_id.item(), "caption": caption})
+
+ return result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating captioning dataset")
+ val_dataset, test_dataset = create_dataset('nocaps', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
+ else:
+ samplers = [None,None]
+
+ val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
+ batch_size=[config['batch_size']]*2,num_workers=[4,4],
+ is_trains=[False, False], collate_fns=[None,None])
+
+ #### Model ####
+ print("Creating model")
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+ prompt=config['prompt'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
+ val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
+ test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/nocaps.yaml')
+ parser.add_argument('--output_dir', default='output/NoCaps')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ args.result_dir = os.path.join(args.output_dir, 'result')
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/repositories/BLIP/eval_retrieval_video.py b/repositories/BLIP/eval_retrieval_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ebab7f41f6466f6f46130002e2e0df1266486a
--- /dev/null
+++ b/repositories/BLIP/eval_retrieval_video.py
@@ -0,0 +1,250 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_retrieval import blip_retrieval
+import utils
+from data.video_dataset import VideoDataset
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, tokenizer, device, config):
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Evaluation:'
+
+ print('Computing features for evaluation...')
+ start_time = time.time()
+
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_ids = []
+ text_embeds = []
+ text_atts = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i: min(num_text, i+text_bs)]
+ text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+ text_embeds.append(text_embed)
+ text_ids.append(text_input.input_ids)
+ text_atts.append(text_input.attention_mask)
+
+ text_embeds = torch.cat(text_embeds,dim=0)
+ text_ids = torch.cat(text_ids,dim=0)
+ text_atts = torch.cat(text_atts,dim=0)
+ text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
+
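+ # per-video features: fold the frame dimension into the batch, take each frame's [CLS] embedding, then mean-pool over frames to get the retrieval embedding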
+ video_feats = []
+ video_embeds = []
+ for video, video_id in data_loader:
+
+ B,N,C,W,H = video.size()
+ video = video.view(-1,C,W,H)
+ video = video.to(device,non_blocking=True)
+ video_feat = model.visual_encoder(video)
+ video_embed = model.vision_proj(video_feat[:,0,:])
+ video_embed = video_embed.view(B,N,-1).mean(dim=1)
+ video_embed = F.normalize(video_embed,dim=-1)
+
+ video_feat = video_feat.view(B,-1,video_feat.shape[-1])
+ video_feats.append(video_feat.cpu())
+ video_embeds.append(video_embed)
+
+ video_feats = torch.cat(video_feats,dim=0)
+ video_embeds = torch.cat(video_embeds,dim=0)
+
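+ # coarse ranking from the contrastive embeddings; the top-k candidates per query are then re-scored with the ITM head, with the rows of the matrix split across ranks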
+ sims_matrix = video_embeds @ text_embeds.t()
+ score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+ num_tasks = utils.get_world_size()
+ rank = utils.get_rank()
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+ encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+ output = model.text_encoder(text_ids[topk_idx],
+ attention_mask = text_atts[topk_idx],
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_v2t[start+i,topk_idx] = score + topk_sim
+
+ sims_matrix = sims_matrix.t()
+ score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+ encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_t2v[start+i,topk_idx] = score + topk_sim
+
+ if args.distributed:
+ dist.barrier()
+ torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
+ torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Evaluation time {}'.format(total_time_str))
+
+ return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
+
+
+
+@torch.no_grad()
+def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
+
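+ # recall@k: position of the ground-truth match when candidates are sorted by similarity score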
+ #Video->Text
+ ranks = np.zeros(scores_v2t.shape[0])
+ for index,score in enumerate(scores_v2t):
+ inds = np.argsort(score)[::-1]
+ ranks[index] = np.where(inds == vid2txt[index])[0][0]
+
+ # Compute metrics
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ #Text->Video
+ ranks = np.zeros(scores_t2v.shape[0])
+
+ for index,score in enumerate(scores_t2v):
+ inds = np.argsort(score)[::-1]
+ ranks[index] = np.where(inds == txt2vmg[index])[0][0]
+
+ mdR = np.median(ranks+1)
+
+ # Compute metrics
+ vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ tr_mean = (tr1 + tr5 + tr10) / 3
+ vr_mean = (vr1 + vr5 + vr10) / 3
+ r_mean = (tr_mean + vr_mean) / 2
+
+ eval_result = {'txt_r1': tr1,
+ 'txt_r5': tr5,
+ 'txt_r10': tr10,
+ 'txt_r_mean': tr_mean,
+ 'vid_r1': vr1,
+ 'vid_r5': vr5,
+ 'vid_r10': vr10,
+ 'vid_r_mean': vr_mean,
+ 'vid_mdR': mdR,
+ 'r_mean': r_mean}
+ return eval_result
+
+
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating retrieval dataset")
+ test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
+ max_img_size=config['image_size'], frm_sampling_strategy='uniform')
+
+ test_loader = DataLoader(
+ test_dataset,
+ batch_size=config['batch_size'],
+ num_workers=4,
+ pin_memory=True,
+ drop_last=False,
+ shuffle=False,
+ )
+
+ #### Model ####
+ print("Creating model")
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
+
+ if utils.is_main_process():
+
+ test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
+ print(test_result)
+
+ log_stats = {**{f'{k}': v for k, v in test_result.items()},}
+ with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
+ parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/repositories/BLIP/models/__init__.py b/repositories/BLIP/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/BLIP/models/blip.py b/repositories/BLIP/models/blip.py
new file mode 100644
index 0000000000000000000000000000000000000000..38678f65ea2c276b351c2c97d429ebc2525ddcf7
--- /dev/null
+++ b/repositories/BLIP/models/blip.py
@@ -0,0 +1,238 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import warnings
+warnings.filterwarnings("ignore")
+
+from models.vit import VisionTransformer, interpolate_pos_embed
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import os
+from urllib.parse import urlparse
+from timm.models.hub import download_cached_file
+
+class BLIP_Base(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 224,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+
+ def forward(self, image, caption, mode):
+
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
+
+ if mode=='image':
+ # return image features
+ image_embeds = self.visual_encoder(image)
+ return image_embeds
+
+ elif mode=='text':
+ # return text features
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ return text_output.last_hidden_state
+
+ elif mode=='multimodal':
+ # return multimodal features
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
+ output = self.text_encoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+ return output.last_hidden_state
+
+
+
+class BLIP_Decoder(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 384,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ prompt = 'a picture of ',
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_decoder = BertLMHeadModel(config=med_config)
+
+ self.prompt = prompt
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
+
+
+ def forward(self, image, caption):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
+
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
+
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
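+ # exclude the prompt tokens from the language-modeling loss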
+ decoder_targets[:,:self.prompt_length] = -100
+
+ decoder_output = self.text_decoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ labels = decoder_targets,
+ return_dict = True,
+ )
+ loss_lm = decoder_output.loss
+
+ return loss_lm
+
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
+ image_embeds = self.visual_encoder(image)
+
+ if not sample:
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
+
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
+
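+ # prepend the captioning prompt to every image and drop its trailing [SEP] token so generation continues from the prompt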
+ prompt = [self.prompt] * image.size(0)
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
+ input_ids[:,0] = self.tokenizer.bos_token_id
+ input_ids = input_ids[:, :-1]
+
+ if sample:
+ #nucleus sampling
+ outputs = self.text_decoder.generate(input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ do_sample=True,
+ top_p=top_p,
+ num_return_sequences=1,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=1.1,
+ **model_kwargs)
+ else:
+ #beam search
+ outputs = self.text_decoder.generate(input_ids=input_ids,
+ max_length=max_length,
+ min_length=min_length,
+ num_beams=num_beams,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ repetition_penalty=repetition_penalty,
+ **model_kwargs)
+
+ captions = []
+ for output in outputs:
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
+ captions.append(caption[len(self.prompt):])
+ return captions
+
+
+def blip_decoder(pretrained='',**kwargs):
+ model = BLIP_Decoder(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ assert(len(msg.missing_keys)==0)
+ return model
+
+def blip_feature_extractor(pretrained='',**kwargs):
+ model = BLIP_Base(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ assert(len(msg.missing_keys)==0)
+ return model
+
+def init_tokenizer():
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
+ return tokenizer
+
+
+def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
+
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
+ if vit=='base':
+ vision_width = 768
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=0 or drop_path_rate
+ )
+ elif vit=='large':
+ vision_width = 1024
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
+ drop_path_rate=0.1 or drop_path_rate
+ )
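+ # note: `0.1 or drop_path_rate` always evaluates to 0.1, so the drop_path_rate argument has no effect for the 'large' variant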
+ return visual_encoder, vision_width
+
+def is_url(url_or_filename):
+ parsed = urlparse(url_or_filename)
+ return parsed.scheme in ("http", "https")
+
+def load_checkpoint(model,url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
+ model.visual_encoder_m)
+ for key in model.state_dict().keys():
+ if key in state_dict.keys():
+ if state_dict[key].shape!=model.state_dict()[key].shape:
+ del state_dict[key]
+
+ msg = model.load_state_dict(state_dict,strict=False)
+ print('load checkpoint from %s'%url_or_filename)
+ return model,msg
+
diff --git a/repositories/BLIP/models/blip_itm.py b/repositories/BLIP/models/blip_itm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf354c829564bf5a1f56089a2d745093d51e0fa2
--- /dev/null
+++ b/repositories/BLIP/models/blip_itm.py
@@ -0,0 +1,76 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_ITM(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 384,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
+ self.itm_head = nn.Linear(text_width, 2)
+
+
+ def forward(self, image, caption, match_head='itm'):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
+ return_tensors="pt").to(image.device)
+
+
+ if match_head=='itm':
+ output = self.text_encoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+ itm_output = self.itm_head(output.last_hidden_state[:,0,:])
+ return itm_output
+
+ elif match_head=='itc':
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+ sim = image_feat @ text_feat.t()
+ return sim
+
+
+def blip_itm(pretrained='',**kwargs):
+ model = BLIP_ITM(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ assert(len(msg.missing_keys)==0)
+ return model
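+
+# Illustrative usage sketch (not part of the original file). Assumes `image` is an
+# already-preprocessed tensor of shape (1, 3, 384, 384) and `ckpt` points to a BLIP-ITM checkpoint:
+#
+#   model = blip_itm(pretrained=ckpt, image_size=384, vit='base').eval()
+#   itm_logits = model(image, 'a dog running on the beach', match_head='itm')
+#   match_prob = torch.softmax(itm_logits, dim=1)[:, 1]  # column 1 = "image and text match"
+#   cosine_sim = model(image, 'a dog running on the beach', match_head='itc')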
+
\ No newline at end of file
diff --git a/repositories/BLIP/models/blip_nlvr.py b/repositories/BLIP/models/blip_nlvr.py
new file mode 100644
index 0000000000000000000000000000000000000000..84837167bfa6874d3c3e41fb9b37271113910b7f
--- /dev/null
+++ b/repositories/BLIP/models/blip_nlvr.py
@@ -0,0 +1,103 @@
+from models.med import BertConfig
+from models.nlvr_encoder import BertModel
+from models.vit import interpolate_pos_embed
+from models.blip import create_vit, init_tokenizer, is_url
+
+from timm.models.hub import download_cached_file
+
+import os  # needed by load_checkpoint below
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import BertTokenizer
+import numpy as np
+
+class BLIP_NLVR(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 480,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+ self.cls_head = nn.Sequential(
+ nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
+ nn.ReLU(),
+ nn.Linear(self.text_encoder.config.hidden_size, 2)
+ )
+
+ def forward(self, image, text, targets, train=True):
+
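+ # each NLVR2 example pairs two images with one sentence; the batch is stacked as [image0; image1], split here, and both halves are cross-attended by the text encoder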
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0))
+
+ text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device)
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
+
+ output = self.text_encoder(text.input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = [image0_embeds,image1_embeds],
+ encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
+ image_atts[image0_embeds.size(0):]],
+ return_dict = True,
+ )
+ hidden_state = output.last_hidden_state[:,0,:]
+ prediction = self.cls_head(hidden_state)
+
+ if train:
+ loss = F.cross_entropy(prediction, targets)
+ return loss
+ else:
+ return prediction
+
+def blip_nlvr(pretrained='',**kwargs):
+ model = BLIP_NLVR(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ print("missing keys:")
+ print(msg.missing_keys)
+ return model
+
+
+def load_checkpoint(model,url_or_filename):
+ if is_url(url_or_filename):
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
+ checkpoint = torch.load(cached_file, map_location='cpu')
+ elif os.path.isfile(url_or_filename):
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
+ else:
+ raise RuntimeError('checkpoint url or path is invalid')
+ state_dict = checkpoint['model']
+
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
+
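+ # the NLVR text encoder has two parallel cross-attention branches (self0/self1, dense0/dense1); copy the pretrained single-branch weights into both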
+ for key in list(state_dict.keys()):
+ if 'crossattention.self.' in key:
+ new_key0 = key.replace('self','self0')
+ new_key1 = key.replace('self','self1')
+ state_dict[new_key0] = state_dict[key]
+ state_dict[new_key1] = state_dict[key]
+ elif 'crossattention.output.dense.' in key:
+ new_key0 = key.replace('dense','dense0')
+ new_key1 = key.replace('dense','dense1')
+ state_dict[new_key0] = state_dict[key]
+ state_dict[new_key1] = state_dict[key]
+
+ msg = model.load_state_dict(state_dict,strict=False)
+ print('load checkpoint from %s'%url_or_filename)
+ return model,msg
+
\ No newline at end of file
diff --git a/repositories/BLIP/models/blip_pretrain.py b/repositories/BLIP/models/blip_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42ce5f998b0a51e6f731ee6b5c8bae6d02a8664
--- /dev/null
+++ b/repositories/BLIP/models/blip_pretrain.py
@@ -0,0 +1,339 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from transformers import BertTokenizer
+import transformers
+transformers.logging.set_verbosity_error()
+logger = transformers.logging.get_logger(__name__)  # used by tie_encoder_decoder_weights below
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_Pretrain(nn.Module):
+ def __init__(self,
+ med_config = 'configs/bert_config.json',
+ image_size = 224,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ queue_size = 57600,
+ momentum = 0.995,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
+
+ if vit=='base':
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+ map_location="cpu", check_hash=True)
+ state_dict = checkpoint["model"]
+ msg = self.visual_encoder.load_state_dict(state_dict,strict=False)
+ elif vit=='large':
+ from timm.models.helpers import load_custom_pretrained
+ from timm.models.vision_transformer import default_cfgs
+ load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k'])
+
+ self.tokenizer = init_tokenizer()
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False)
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
+ self.itm_head = nn.Linear(text_width, 2)
+
+ # create momentum encoders
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
+ self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False)
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
+
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
+ [self.vision_proj,self.vision_proj_m],
+ [self.text_encoder,self.text_encoder_m],
+ [self.text_proj,self.text_proj_m],
+ ]
+ self.copy_params()
+
+ # create the queue
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
+
+ self.queue_size = queue_size
+ self.momentum = momentum
+ self.temp = nn.Parameter(0.07*torch.ones([]))
+
+ # create the decoder
+ decoder_config = BertConfig.from_json_file(med_config)
+ decoder_config.encoder_width = vision_width
+ self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
+ self.text_decoder.resize_token_embeddings(len(self.tokenizer))
+ tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
+
+
+ def forward(self, image, caption, alpha):
+ with torch.no_grad():
+ self.temp.clamp_(0.001,0.5)
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30,
+ return_tensors="pt").to(image.device)
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+ # get momentum features
+ with torch.no_grad():
+ self._momentum_update()
+ image_embeds_m = self.visual_encoder_m(image)
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
+ image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
+
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
+ text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
+
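+ # momentum distillation: soft contrastive targets mix the momentum encoders' softmaxed similarities with the one-hot targets, weighted by alpha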
+ sim_i2t_m = image_feat_m @ text_feat_all / self.temp
+ sim_t2i_m = text_feat_m @ image_feat_all / self.temp
+
+ sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
+ sim_targets.fill_diagonal_(1)
+
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
+
+ sim_i2t = image_feat @ text_feat_all / self.temp
+ sim_t2i = text_feat @ image_feat_all / self.temp
+
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
+
+ loss_ita = (loss_i2t+loss_t2i)/2
+
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m)
+
+ ###============== Image-text Matching ===================###
+ encoder_input_ids = text.input_ids.clone()
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
+
+ # forward the positive image-text pair
+ bs = image.size(0)
+ output_pos = self.text_encoder(encoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
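+ # hard negative mining: in-batch negatives are sampled with probability proportional to the contrastive similarity, with the positive pairs on the diagonal zeroed out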
+ with torch.no_grad():
+ weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4
+ weights_t2i.fill_diagonal_(0)
+ weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4
+ weights_i2t.fill_diagonal_(0)
+
+ # select a negative image for each text
+ image_embeds_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+ image_embeds_neg.append(image_embeds[neg_idx])
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
+
+ # select a negative text for each image
+ text_ids_neg = []
+ text_atts_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+ text_ids_neg.append(encoder_input_ids[neg_idx])
+ text_atts_neg.append(text.attention_mask[neg_idx])
+
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
+
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
+
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
+
+ output_neg = self.text_encoder(text_ids_all,
+ attention_mask = text_atts_all,
+ encoder_hidden_states = image_embeds_all,
+ encoder_attention_mask = image_atts_all,
+ return_dict = True,
+ )
+
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
+ vl_output = self.itm_head(vl_embeddings)
+
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
+ dim=0).to(image.device)
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
+
+ ##================= LM ========================##
+ decoder_input_ids = text.input_ids.clone()
+ decoder_input_ids[:,0] = self.tokenizer.bos_token_id
+ decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)
+
+ decoder_output = self.text_decoder(decoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ labels = decoder_targets,
+ return_dict = True,
+ )
+
+ loss_lm = decoder_output.loss
+ return loss_ita, loss_itm, loss_lm
+
+
+
+ @torch.no_grad()
+ def copy_params(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data.copy_(param.data) # initialize
+ param_m.requires_grad = False # not update by gradient
+
+
+ @torch.no_grad()
+ def _momentum_update(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
+
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, image_feat, text_feat):
+ # gather keys before updating queue
+ image_feats = concat_all_gather(image_feat)
+ text_feats = concat_all_gather(text_feat)
+
+ batch_size = image_feats.shape[0]
+
+ ptr = int(self.queue_ptr)
+ assert self.queue_size % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
+
+ self.queue_ptr[0] = ptr
+
+
+def blip_pretrain(**kwargs):
+ model = BLIP_Pretrain(**kwargs)
+ return model
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+from typing import List
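+# recursively tie the encoder's parameters to the decoder's so they are shared, skipping any module whose path contains skip_key (BLIP passes '/attention', which keeps the self-attention layers separate)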
+def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
+ uninitialized_encoder_weights: List[str] = []
+ if decoder.__class__ != encoder.__class__:
+ logger.info(
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
+ )
+
+ def tie_encoder_to_decoder_recursively(
+ decoder_pointer: nn.Module,
+ encoder_pointer: nn.Module,
+ module_name: str,
+ uninitialized_encoder_weights: List[str],
+ skip_key: str,
+ depth=0,
+ ):
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
+ encoder_pointer, nn.Module
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
+ assert hasattr(encoder_pointer, "weight")
+ encoder_pointer.weight = decoder_pointer.weight
+ if hasattr(decoder_pointer, "bias"):
+ assert hasattr(encoder_pointer, "bias")
+ encoder_pointer.bias = decoder_pointer.bias
+ print(module_name+' is tied')
+ return
+
+ encoder_modules = encoder_pointer._modules
+ decoder_modules = decoder_pointer._modules
+ if len(decoder_modules) > 0:
+ assert (
+ len(encoder_modules) > 0
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+ all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
+ encoder_layer_pos = 0
+ for name, module in decoder_modules.items():
+ if name.isdigit():
+ encoder_name = str(int(name) + encoder_layer_pos)
+ decoder_name = name
+ if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
+ encoder_modules
+ ) != len(decoder_modules):
+ # this can happen if the name corresponds to the position in a ModuleList of layers
+ # in this case the decoder has added a cross-attention that the encoder does not have
+ # thus skip this step and subtract one layer pos from encoder
+ encoder_layer_pos -= 1
+ continue
+ elif name not in encoder_modules:
+ continue
+ elif depth > 500:
+ raise ValueError(
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
+ )
+ else:
+ decoder_name = encoder_name = name
+ tie_encoder_to_decoder_recursively(
+ decoder_modules[decoder_name],
+ encoder_modules[encoder_name],
+ module_name + "/" + name,
+ uninitialized_encoder_weights,
+ skip_key,
+ depth=depth + 1,
+ )
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+ uninitialized_encoder_weights += list(all_encoder_weights)
+
+ # tie weights recursively
+ tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
diff --git a/repositories/BLIP/models/blip_retrieval.py b/repositories/BLIP/models/blip_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1debe7e2e664f8dd603f8d4c537e3599c68638d7
--- /dev/null
+++ b/repositories/BLIP/models/blip_retrieval.py
@@ -0,0 +1,319 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_Retrieval(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 384,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ embed_dim = 256,
+ queue_size = 57600,
+ momentum = 0.995,
+ negative_all_rank = False,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+ self.tokenizer = init_tokenizer()
+ med_config = BertConfig.from_json_file(med_config)
+ med_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+ text_width = self.text_encoder.config.hidden_size
+
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
+ self.text_proj = nn.Linear(text_width, embed_dim)
+
+ self.itm_head = nn.Linear(text_width, 2)
+
+ # create momentum encoders
+ self.visual_encoder_m, vision_width = create_vit(vit,image_size)
+ self.vision_proj_m = nn.Linear(vision_width, embed_dim)
+ self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False)
+ self.text_proj_m = nn.Linear(text_width, embed_dim)
+
+ self.model_pairs = [[self.visual_encoder,self.visual_encoder_m],
+ [self.vision_proj,self.vision_proj_m],
+ [self.text_encoder,self.text_encoder_m],
+ [self.text_proj,self.text_proj_m],
+ ]
+ self.copy_params()
+
+ # create the queue
+ self.register_buffer("image_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("text_queue", torch.randn(embed_dim, queue_size))
+ self.register_buffer("idx_queue", torch.full((1,queue_size),-100))
+ self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long))
+
+ self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
+ self.text_queue = nn.functional.normalize(self.text_queue, dim=0)
+
+ self.queue_size = queue_size
+ self.momentum = momentum
+ self.temp = nn.Parameter(0.07*torch.ones([]))
+
+ self.negative_all_rank = negative_all_rank
+
+
+ def forward(self, image, caption, alpha, idx):
+ with torch.no_grad():
+ self.temp.clamp_(0.001,0.5)
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+ image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+
+ text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
+ return_tensors="pt").to(image.device)
+
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+ ###============== Image-text Contrastive Learning ===================###
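+ # idx identifies the ground-truth image-text pair: queue entries with the same idx count as extra positives, and the target distribution is normalized over all positives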
+ idx = idx.view(-1,1)
+ idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1)
+ pos_idx = torch.eq(idx, idx_all).float()
+ sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)
+
+ # get momentum features
+ with torch.no_grad():
+ self._momentum_update()
+ image_embeds_m = self.visual_encoder_m(image)
+ image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1)
+ image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1)
+
+ text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
+ return_dict = True, mode = 'text')
+ text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1)
+ text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1)
+
+ sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp
+ sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp
+
+ sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
+ sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets
+
+ sim_i2t = image_feat @ text_feat_m_all / self.temp
+ sim_t2i = text_feat @ image_feat_m_all / self.temp
+
+ loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean()
+ loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean()
+
+ loss_ita = (loss_i2t+loss_t2i)/2
+
+ idxs = concat_all_gather(idx)
+ self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs)
+
+ ###============== Image-text Matching ===================###
+ encoder_input_ids = text.input_ids.clone()
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
+
+ # forward the positive image-text pair
+ bs = image.size(0)
+ output_pos = self.text_encoder(encoder_input_ids,
+ attention_mask = text.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True,
+ )
+
+
+ if self.negative_all_rank:
+ # compute sample similarity
+ with torch.no_grad():
+ mask = torch.eq(idx, idxs.t())
+
+ image_feat_world = concat_all_gather(image_feat)
+ text_feat_world = concat_all_gather(text_feat)
+
+ sim_i2t = image_feat @ text_feat_world.t() / self.temp
+ sim_t2i = text_feat @ image_feat_world.t() / self.temp
+
+ weights_i2t = F.softmax(sim_i2t,dim=1)
+ weights_i2t.masked_fill_(mask, 0)
+
+ weights_t2i = F.softmax(sim_t2i,dim=1)
+ weights_t2i.masked_fill_(mask, 0)
+
+ image_embeds_world = all_gather_with_grad(image_embeds)
+
+ # select a negative image (from all ranks) for each text
+ image_embeds_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+ image_embeds_neg.append(image_embeds_world[neg_idx])
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
+
+ # select a negative text (from all ranks) for each image
+ input_ids_world = concat_all_gather(encoder_input_ids)
+ att_mask_world = concat_all_gather(text.attention_mask)
+
+ text_ids_neg = []
+ text_atts_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+ text_ids_neg.append(input_ids_world[neg_idx])
+ text_atts_neg.append(att_mask_world[neg_idx])
+
+ else:
+ with torch.no_grad():
+ mask = torch.eq(idx, idx.t())
+
+ sim_i2t = image_feat @ text_feat.t() / self.temp
+ sim_t2i = text_feat @ image_feat.t() / self.temp
+
+ weights_i2t = F.softmax(sim_i2t,dim=1)
+ weights_i2t.masked_fill_(mask, 0)
+
+ weights_t2i = F.softmax(sim_t2i,dim=1)
+ weights_t2i.masked_fill_(mask, 0)
+
+ # select a negative image (from same rank) for each text
+ image_embeds_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+ image_embeds_neg.append(image_embeds[neg_idx])
+ image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
+
+ # select a negative text (from same rank) for each image
+ text_ids_neg = []
+ text_atts_neg = []
+ for b in range(bs):
+ neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+ text_ids_neg.append(encoder_input_ids[neg_idx])
+ text_atts_neg.append(text.attention_mask[neg_idx])
+
+ text_ids_neg = torch.stack(text_ids_neg,dim=0)
+ text_atts_neg = torch.stack(text_atts_neg,dim=0)
+
+ text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0)
+ text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0)
+
+ image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
+ image_atts_all = torch.cat([image_atts,image_atts],dim=0)
+
+ output_neg = self.text_encoder(text_ids_all,
+ attention_mask = text_atts_all,
+ encoder_hidden_states = image_embeds_all,
+ encoder_attention_mask = image_atts_all,
+ return_dict = True,
+ )
+
+
+ vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0)
+ vl_output = self.itm_head(vl_embeddings)
+
+ itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)],
+ dim=0).to(image.device)
+ loss_itm = F.cross_entropy(vl_output, itm_labels)
+
+ return loss_ita, loss_itm
+
+
+ @torch.no_grad()
+ def copy_params(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data.copy_(param.data) # initialize
+ param_m.requires_grad = False # not update by gradient
+
+
+ @torch.no_grad()
+ def _momentum_update(self):
+ for model_pair in self.model_pairs:
+ for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
+ param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
+
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs):
+ # gather keys before updating queue
+ image_feats = concat_all_gather(image_feat)
+ text_feats = concat_all_gather(text_feat)
+
+
+ batch_size = image_feats.shape[0]
+
+ ptr = int(self.ptr_queue)
+ assert self.queue_size % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
+ self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
+ self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
+
+ self.ptr_queue[0] = ptr
+
+
+def blip_retrieval(pretrained='',**kwargs):
+ model = BLIP_Retrieval(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+ print("missing keys:")
+ print(msg.missing_keys)
+ return model
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
+
+
+class GatherLayer(torch.autograd.Function):
+ """
+ Gather tensors from all workers with support for backward propagation:
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
+ """
+
+ @staticmethod
+ def forward(ctx, x):
+ output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(output, x)
+ return tuple(output)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ all_gradients = torch.stack(grads)
+ torch.distributed.all_reduce(all_gradients)
+ return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+ """
+ Performs all_gather operation on the provided tensors.
+ Graph remains connected for backward grad computation.
+ """
+ # determine how many processes are participating
+ world_size = torch.distributed.get_world_size()
+ # There is no need for reduction in the single-proc case
+ if world_size == 1:
+ return tensors
+
+ tensor_all = GatherLayer.apply(tensors)
+
+ return torch.cat(tensor_all, dim=0)
diff --git a/repositories/BLIP/models/blip_vqa.py b/repositories/BLIP/models/blip_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4cb3688fad03888f8568ec65437ee20452c6cb8
--- /dev/null
+++ b/repositories/BLIP/models/blip_vqa.py
@@ -0,0 +1,186 @@
+from models.med import BertConfig, BertModel, BertLMHeadModel
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import BertTokenizer
+import numpy as np
+
+class BLIP_VQA(nn.Module):
+ def __init__(self,
+ med_config = 'configs/med_config.json',
+ image_size = 480,
+ vit = 'base',
+ vit_grad_ckpt = False,
+ vit_ckpt_layer = 0,
+ ):
+ """
+ Args:
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
+ image_size (int): input image size
+ vit (str): model size of vision transformer
+ """
+ super().__init__()
+
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1)
+ self.tokenizer = init_tokenizer()
+
+ encoder_config = BertConfig.from_json_file(med_config)
+ encoder_config.encoder_width = vision_width
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
+
+ decoder_config = BertConfig.from_json_file(med_config)
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
+
+
+ def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
+
+ image_embeds = self.visual_encoder(image)
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+ question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
+ return_tensors="pt").to(image.device)
+ question.input_ids[:,0] = self.tokenizer.enc_token_id
+
+ if train:
+ '''
+ n: number of answers for each question
+ weights: weight for each answer
+ '''
+ answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device)
+ answer.input_ids[:,0] = self.tokenizer.bos_token_id
+ answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)
+
+ question_output = self.text_encoder(question.input_ids,
+ attention_mask = question.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True)
+
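+ # each question comes with n[b] ground-truth answers; replicate its encoder states and attention mask so every (question, answer) pair is scored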
+ question_states = []
+ question_atts = []
+ for b, n in enumerate(n):
+ question_states += [question_output.last_hidden_state[b]]*n
+ question_atts += [question.attention_mask[b]]*n
+ question_states = torch.stack(question_states,0)
+ question_atts = torch.stack(question_atts,0)
+
+ answer_output = self.text_decoder(answer.input_ids,
+ attention_mask = answer.attention_mask,
+ encoder_hidden_states = question_states,
+ encoder_attention_mask = question_atts,
+ labels = answer_targets,
+ return_dict = True,
+ reduction = 'none',
+ )
+
+ loss = weights * answer_output.loss
+ loss = loss.sum()/image.size(0)
+
+ return loss
+
+
+ else:
+ question_output = self.text_encoder(question.input_ids,
+ attention_mask = question.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True)
+
+ if inference=='generate':
+ num_beams = 3
+ question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+ question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
+ model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
+
+ bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device)
+
+ outputs = self.text_decoder.generate(input_ids=bos_ids,
+ max_length=10,
+ min_length=1,
+ num_beams=num_beams,
+ eos_token_id=self.tokenizer.sep_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ **model_kwargs)
+
+ answers = []
+ for output in outputs:
+ answer = self.tokenizer.decode(output, skip_special_tokens=True)
+ answers.append(answer)
+ return answers
+
+ elif inference=='rank':
+ max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
+ answer.input_ids, answer.attention_mask, k_test)
+ return max_ids
+
+
+
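+ # two-stage answer ranking: shortlist k candidates by their first-token probability, then re-score each shortlisted answer by its full language-modeling likelihood and keep the best one per question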
+ def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
+
+ num_ques = question_states.size(0)
+ start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
+
+ start_output = self.text_decoder(start_ids,
+ encoder_hidden_states = question_states,
+ encoder_attention_mask = question_atts,
+ return_dict = True,
+ reduction = 'none')
+ logits = start_output.logits[:,0,:] # first token's logit
+
+ # topk_probs: top-k probability
+ # topk_ids: [num_question, k]
+ answer_first_token = answer_ids[:,1]
+ prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token)
+ topk_probs, topk_ids = prob_first_token.topk(k,dim=1)
+
+ # answer input: [num_question*k, answer_len]
+ input_ids = []
+ input_atts = []
+ for b, topk_id in enumerate(topk_ids):
+ input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
+ input_ids = torch.cat(input_ids,dim=0)
+ input_atts = torch.cat(input_atts,dim=0)
+
+ targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)
+
+ # repeat encoder's output for top-k answers
+ question_states = tile(question_states, 0, k)
+ question_atts = tile(question_atts, 0, k)
+
+ output = self.text_decoder(input_ids,
+ attention_mask = input_atts,
+ encoder_hidden_states = question_states,
+ encoder_attention_mask = question_atts,
+ labels = targets_ids,
+ return_dict = True,
+ reduction = 'none')
+
+ log_probs_sum = -output.loss
+ log_probs_sum = log_probs_sum.view(num_ques,k)
+
+ max_topk_ids = log_probs_sum.argmax(dim=1)
+ max_ids = topk_ids[max_topk_ids>=0,max_topk_ids]
+
+ return max_ids
+
+
+def blip_vqa(pretrained='',**kwargs):
+ model = BLIP_VQA(**kwargs)
+ if pretrained:
+ model,msg = load_checkpoint(model,pretrained)
+# assert(len(msg.missing_keys)==0)
+ return model
+
+
+def tile(x, dim, n_tile):
+ init_dim = x.size(dim)
+ repeat_idx = [1] * x.dim()
+ repeat_idx[dim] = n_tile
+ x = x.repeat(*(repeat_idx))
+ order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
+ return torch.index_select(x, dim, order_index.to(x.device))
+
+
\ No newline at end of file
diff --git a/repositories/BLIP/models/med.py b/repositories/BLIP/models/med.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b00a35450b736180a805d4f4664b4fb95aeba01
--- /dev/null
+++ b/repositories/BLIP/models/med.py
@@ -0,0 +1,955 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+'''
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
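+ # Reshape (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size).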
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
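+ # Cached self-attention: project the current tokens, then append them to the cached
+ # keys/values along the sequence dimension (dim=2).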
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
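+ # Add learned relative-position scores derived from the pairwise token distances.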
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
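+ # Scale by sqrt(head_size), as in standard scaled dot-product attention.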
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False):
+ super().__init__()
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
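+ # In multimodal mode, follow self-attention with cross-attention over the encoder (image) hidden states.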
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
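+ # torch.utils.checkpoint re-invokes the wrapped function with its saved positional inputs,
+ # so the remaining arguments are bound in a closure instead of being passed through.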
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ To use the cross-attention layers, the model needs to be initialized with :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
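+ # Lower-triangular causal mask: position i may attend to position j only when j <= i.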
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
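+ # Unlike the stock BertModel, callers may supply precomputed embeddings (encoder_embeds)
+ # and bypass the embedding layer entirely.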
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ return_logits=False,
+ is_decoder=True,
+ reduction='mean',
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ ``[-100, 0, ..., config.vocab_size]`` (see the ``input_ids`` docstring). Tokens with indices set to ``-100`` are
+ ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ Returns:
+ Example::
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+ >>> import torch
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> prediction_logits = outputs.logits
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ is_decoder=is_decoder,
+ mode=mode,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
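+ # When return_logits is set, skip the loss and return the logits for every position except the last.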
+ if return_logits:
+ return prediction_scores[:, :-1, :].contiguous()
+
+ lm_loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ if reduction=='none':
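+ # Collapse the per-token losses into a single summed loss per sequence.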
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+ "is_decoder": True,
+ }
+
+ def _reorder_cache(self, past, beam_idx):
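+ # Reorder each layer's cached key/value tensors to follow the beam indices chosen during beam search.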
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
diff --git a/repositories/BLIP/models/nlvr_encoder.py b/repositories/BLIP/models/nlvr_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1946bb4a300f75afa4848f6622839445903c34a9
--- /dev/null
+++ b/repositories/BLIP/models/nlvr_encoder.py
@@ -0,0 +1,843 @@
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+ ModelOutput,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word and position embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+
+ self.config = config
+
+ def forward(
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ embeddings = inputs_embeds
+
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config, is_cross_attention):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ if is_cross_attention:
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
+ else:
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+ self.save_attention = False
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ seq_length = hidden_states.size()[1]
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if is_cross_attention and self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config, twin=False, merge=False):
+ super().__init__()
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
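+ # NLVR2 pairs two images with one sentence: `twin` keeps a separate output projection per
+ # cross-attention branch, and `merge` fuses the two branches with a linear layer instead of averaging them.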
+ if twin:
+ self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
+ else:
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if merge:
+ self.act = ACT2FN[config.hidden_act]
+ self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
+ self.merge = True
+ else:
+ self.merge = False
+
+ def forward(self, hidden_states, input_tensor):
+ if type(hidden_states) == list:
+ hidden_states0 = self.dense0(hidden_states[0])
+ hidden_states1 = self.dense1(hidden_states[1])
+ if self.merge:
+ #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1)))
+ hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1))
+ else:
+ hidden_states = (hidden_states0+hidden_states1)/2
+ else:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config, is_cross_attention=False, layer_num=-1):
+ super().__init__()
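+ # Cross-attention layers hold two parallel attention modules (one per image); their outputs are
+ # combined by BertSelfOutput, which merges with a linear layer from layer 6 onward and averages before that.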
+ if is_cross_attention:
+ self.self0 = BertSelfAttention(config, is_cross_attention)
+ self.self1 = BertSelfAttention(config, is_cross_attention)
+ else:
+ self.self = BertSelfAttention(config, is_cross_attention)
+ self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6))
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ if type(encoder_hidden_states)==list:
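+ # Two-image case: run each cross-attention branch on its own encoder states and mask,
+ # then let the output module combine the two results.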
+ self_outputs0 = self.self0(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states[0],
+ encoder_attention_mask[0],
+ past_key_value,
+ output_attentions,
+ )
+ self_outputs1 = self.self1(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states[1],
+ encoder_attention_mask[1],
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states)
+
+ outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them
+ else:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config, layer_num):
+ super().__init__()
+ self.config = config
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = BertAttention(config)
+ self.layer_num = layer_num
+ if self.config.add_cross_attention:
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ mode=None,
+ ):
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+
+ if mode=='multimodal':
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ mode='multimodal',
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ next_decoder_cache = () if use_cache else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layer[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ mode=mode,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ mode=mode,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = BertPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = BertLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = BertConfig
+ base_model_prefix = "bert"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def _init_weights(self, module):
+ """ Initialize the weights """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+ """
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ To use the cross-attention layers, the model needs to be initialized with :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+ input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = BertEmbeddings(config)
+
+ self.encoder = BertEncoder(config)
+
+ self.pooler = BertPooler(config) if add_pooling_layer else None
+
+ self.init_weights()
+
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (:obj:`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (:obj:`Tuple[int]`):
+ The shape of the input to the model.
+ device: (:obj:`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if is_decoder:
+ batch_size, seq_length = input_shape
+
+ seq_ids = torch.arange(seq_length, device=device)
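+ # Lower-triangular causal mask: position i may attend to position j only when j <= i.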
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.to(attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = torch.cat(
+ [
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ is_decoder=False,
+ mode='multimodal',
+ ):
+ r"""
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+ use_cache (:obj:`bool`, `optional`):
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+ decoding (see :obj:`past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ device = input_ids.device
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = inputs_embeds.device
+ elif encoder_embeds is not None:
+ input_shape = encoder_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = encoder_embeds.device
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
+ device, is_decoder)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if encoder_embeds is None:
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ else:
+ embedding_output = encoder_embeds
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ mode=mode,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
diff --git a/repositories/BLIP/models/vit.py b/repositories/BLIP/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899
--- /dev/null
+++ b/repositories/BLIP/models/vit.py
@@ -0,0 +1,305 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on timm code base
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed, resize_pos_embed  # resize_pos_embed is referenced by _load_weights below
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.attn_gradients = None
+ self.attention_map = None
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def forward(self, x, register_hook=False):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ if register_hook:
+ self.save_attention_map(attn)
+ attn.register_hook(self.save_attn_gradients)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
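+# Example usage of the Attention module above (illustrative sketch, not part of
+# the upstream BLIP file):
+#   attn = Attention(dim=768, num_heads=12)
+#   tokens = torch.randn(2, 197, 768)   # (batch, tokens, dim)
+#   out = attn(tokens)                  # same shape: (2, 197, 768)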
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if use_grad_checkpointing:
+ self.attn = checkpoint_wrapper(self.attn)
+ self.mlp = checkpoint_wrapper(self.mlp)
+
+ def forward(self, x, register_hook=False):
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
+ https://arxiv.org/abs/2010.11929
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
+ use_grad_checkpointing=False, ckpt_layer=0):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer
+ """
+ super().__init__()
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
+ )
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward(self, x, register_blk=-1):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + self.pos_embed[:,:x.size(1),:]
+ x = self.pos_drop(x)
+
+ for i,blk in enumerate(self.blocks):
+ x = blk(x, register_blk==i)
+ x = self.norm(x)
+
+ return x
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
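+# Example usage (illustrative sketch; the defaults above correspond to a ViT-B/16
+# backbone):
+#   vit = VisionTransformer(img_size=224, patch_size=16)
+#   feats = vit(torch.randn(1, 3, 224, 224))   # -> (1, 197, 768): [CLS] + 14*14 patches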
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
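+# Illustrative use of _load_weights (assumes a Flax ViT .npz checkpoint is
+# available locally; the filename below is a placeholder):
+#   vit = VisionTransformer()
+#   _load_weights(vit, 'ViT-B_16.npz')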
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+ # interpolate position embedding
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = visual_encoder.patch_embed.num_patches
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+
+ if orig_size!=new_size:
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
+
+ return new_pos_embed
+ else:
+ return pos_embed_checkpoint
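+# Worked example (illustrative): a checkpoint trained at 224x224 with patch size 16
+# carries a (1, 197, 768) pos_embed (1 [CLS] token + 14*14 patches). Loading it into
+# an encoder built for 384x384 inputs (24*24 = 576 patches) interpolates the 196
+# patch tokens bicubically to 576, yielding a new pos_embed of shape (1, 577, 768).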
\ No newline at end of file
diff --git a/repositories/BLIP/predict.py b/repositories/BLIP/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..35426cadcbb3bf8c3d8cb9c910511c154e451f4e
--- /dev/null
+++ b/repositories/BLIP/predict.py
@@ -0,0 +1,98 @@
+"""
+Download the weights in ./checkpoints beforehand for fast inference
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth
+wget https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth
+"""
+
+from pathlib import Path
+
+from PIL import Image
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import cog
+
+from models.blip import blip_decoder
+from models.blip_vqa import blip_vqa
+from models.blip_itm import blip_itm
+
+
+class Predictor(cog.Predictor):
+ def setup(self):
+ self.device = "cuda:0"
+
+ self.models = {
+ 'image_captioning': blip_decoder(pretrained='checkpoints/model*_base_caption.pth',
+ image_size=384, vit='base'),
+ 'visual_question_answering': blip_vqa(pretrained='checkpoints/model*_vqa.pth',
+ image_size=480, vit='base'),
+ 'image_text_matching': blip_itm(pretrained='checkpoints/model_base_retrieval_coco.pth',
+ image_size=384, vit='base')
+ }
+
+ @cog.input(
+ "image",
+ type=Path,
+ help="input image",
+ )
+ @cog.input(
+ "task",
+ type=str,
+ default='image_captioning',
+ options=['image_captioning', 'visual_question_answering', 'image_text_matching'],
+ help="Choose a task.",
+ )
+ @cog.input(
+ "question",
+ type=str,
+ default=None,
+ help="Type question for the input image for visual question answering task.",
+ )
+ @cog.input(
+ "caption",
+ type=str,
+ default=None,
+ help="Type caption for the input image for image text matching task.",
+ )
+ def predict(self, image, task, question, caption):
+ if task == 'visual_question_answering':
+            assert question is not None, 'Please type a question for the visual question answering task.'
+        if task == 'image_text_matching':
+            assert caption is not None, 'Please type a caption for the image text matching task.'
+
+ im = load_image(image, image_size=480 if task == 'visual_question_answering' else 384, device=self.device)
+ model = self.models[task]
+ model.eval()
+ model = model.to(self.device)
+
+ if task == 'image_captioning':
+ with torch.no_grad():
+ caption = model.generate(im, sample=False, num_beams=3, max_length=20, min_length=5)
+ return 'Caption: ' + caption[0]
+
+ if task == 'visual_question_answering':
+ with torch.no_grad():
+ answer = model(im, question, train=False, inference='generate')
+ return 'Answer: ' + answer[0]
+
+ # image_text_matching
+ itm_output = model(im, caption, match_head='itm')
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
+ itc_score = model(im, caption, match_head='itc')
+        return f'The image and text are matched with a probability of {itm_score.item():.4f}.\n' \
+            f'The image feature and text feature have a cosine similarity of {itc_score.item():.4f}.'
+
+
+def load_image(image, image_size, device):
+ raw_image = Image.open(str(image)).convert('RGB')
+
+ w, h = raw_image.size
+
+ transform = transforms.Compose([
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+ image = transform(raw_image).unsqueeze(0).to(device)
+ return image
diff --git a/repositories/BLIP/pretrain.py b/repositories/BLIP/pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9490ec8eb8ff5f074b5772ada55cd27ec673a12
--- /dev/null
+++ b/repositories/BLIP/pretrain.py
@@ -0,0 +1,173 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_pretrain import blip_pretrain
+import utils
+from utils import warmup_lr_schedule, step_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+def train(model, data_loader, optimizer, epoch, device, config):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+ metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ if config['laion_path']:
+ data_loader.dataset.reload_laion(epoch)
+
+ data_loader.sampler.set_epoch(epoch)
+
+ for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+ if epoch==0:
+ warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr'])
+
+ optimizer.zero_grad()
+
+ image = image.to(device,non_blocking=True)
+
+ # ramp up alpha in the first 2 epochs
+ alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader)))
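+        # e.g. with alpha = 0.4 and 1,000 batches per epoch, alpha rises linearly
+        # from 0 at the first step to 0.4 after 2,000 steps and then stays constant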
+
+ loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha)
+ loss = loss_ita + loss_itm + loss_lm
+
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss_ita=loss_ita.item())
+ metric_logger.update(loss_itm=loss_itm.item())
+ metric_logger.update(loss_lm=loss_lm.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating dataset")
+ datasets = [create_dataset('pretrain', config, min_scale=0.2)]
+ print('number of training samples: %d'%len(datasets[0]))
+
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler(datasets, [True], num_tasks, global_rank)
+
+ data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0]
+
+ #### Model ####
+ print("Creating model")
+ model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
+ vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size'])
+
+ model = model.to(device)
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ start_epoch = 0
+ if args.checkpoint:
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ state_dict = checkpoint['model']
+ model.load_state_dict(state_dict)
+
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ start_epoch = checkpoint['epoch']+1
+ print('resume checkpoint from %s'%args.checkpoint)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(start_epoch, config['max_epoch']):
+
+ step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate'])
+
+ train_stats = train(model, data_loader, optimizer, epoch, device, config)
+ if utils.is_main_process():
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ }
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
+
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/pretrain.yaml')
+ parser.add_argument('--output_dir', default='output/Pretrain')
+ parser.add_argument('--checkpoint', default='')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/repositories/BLIP/requirements.txt b/repositories/BLIP/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d897bc6a08712f4beb2f78ca2592dcbe06a3e2db
--- /dev/null
+++ b/repositories/BLIP/requirements.txt
@@ -0,0 +1,4 @@
+timm==0.4.12
+transformers==4.15.0
+fairscale==0.4.4
+pycocoevalcap
diff --git a/repositories/BLIP/train_caption.py b/repositories/BLIP/train_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c639ac646b9a1b8074b6e9c2343b961de76db05
--- /dev/null
+++ b/repositories/BLIP/train_caption.py
@@ -0,0 +1,206 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip import blip_decoder
+import utils
+from utils import cosine_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+from data.utils import save_result, coco_caption_eval
+
+def train(model, data_loader, optimizer, epoch, device):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+ header = 'Train Caption Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image = image.to(device)
+
+ loss = model(image, caption)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+ # evaluate
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Caption generation:'
+ print_freq = 10
+
+ result = []
+ for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+ image = image.to(device)
+
+ captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+ min_length=config['min_length'])
+
+ for caption, img_id in zip(captions, image_id):
+ result.append({"image_id": img_id.item(), "caption": caption})
+
+ return result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating captioning dataset")
+ train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
+ else:
+ samplers = [None, None, None]
+
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
+ batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
+ is_trains=[True, False, False], collate_fns=[None,None,None])
+
+ #### Model ####
+ print("Creating model")
+ model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
+ prompt=config['prompt'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ best = 0
+ best_epoch = 0
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device)
+
+ val_result = evaluate(model_without_ddp, val_loader, device, config)
+ val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
+
+ test_result = evaluate(model_without_ddp, test_loader, device, config)
+ test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
+
+ if utils.is_main_process():
+ coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
+ coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
+
+ if args.evaluate:
+ log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
+ }
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ else:
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+
+ if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
+ best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
+ best_epoch = epoch
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in coco_val.eval.items()},
+ **{f'test_{k}': v for k, v in coco_test.eval.items()},
+ 'epoch': epoch,
+ 'best_epoch': best_epoch,
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if args.evaluate:
+ break
+ dist.barrier()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/caption_coco.yaml')
+ parser.add_argument('--output_dir', default='output/Caption_coco')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ args.result_dir = os.path.join(args.output_dir, 'result')
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
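+
+# Example launch (illustrative; the exact launcher and GPU count depend on the setup):
+#   python -m torch.distributed.run --nproc_per_node=8 train_caption.py \
+#       --config ./configs/caption_coco.yaml --output_dir output/Caption_coco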
\ No newline at end of file
diff --git a/repositories/BLIP/train_nlvr.py b/repositories/BLIP/train_nlvr.py
new file mode 100644
index 0000000000000000000000000000000000000000..84b247bda2334c1fd894b6c11d33ef48c8e7df28
--- /dev/null
+++ b/repositories/BLIP/train_nlvr.py
@@ -0,0 +1,213 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+import pickle
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from models.blip_nlvr import blip_nlvr
+
+import utils
+from utils import cosine_lr_schedule, warmup_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+def train(model, data_loader, optimizer, epoch, device, config):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+ step_size = 10
+
+ for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+ images = torch.cat([image0, image1], dim=0)
+ images, targets = images.to(device), targets.to(device)
+
+ loss = model(images, text, targets=targets, train=True)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ metric_logger.update(loss=loss.item())
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, config):
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+
+ header = 'Evaluation:'
+ print_freq = 50
+
+ for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
+ images = torch.cat([image0, image1], dim=0)
+ images, targets = images.to(device), targets.to(device)
+
+ prediction = model(images, text, targets=targets, train=False)
+
+ _, pred_class = prediction.max(1)
+ accuracy = (targets==pred_class).sum() / targets.size(0)
+
+ metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating dataset")
+ datasets = create_dataset('nlvr', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank)
+ else:
+ samplers = [None, None, None]
+
+ batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']]
+ train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size,
+ num_workers=[4,4,4],is_trains=[True,False,False],
+ collate_fns=[None,None,None])
+
+ #### Model ####
+ print("Creating model")
+ model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ print("Start training")
+ start_time = time.time()
+ best = 0
+ best_epoch = 0
+
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
+
+ val_stats = evaluate(model, val_loader, device, config)
+ test_stats = evaluate(model, test_loader, device, config)
+
+ if utils.is_main_process():
+ if args.evaluate:
+ log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in val_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ }
+
+ if float(val_stats['acc'])>best:
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+ best = float(val_stats['acc'])
+ best_epoch = epoch
+
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ if args.evaluate:
+ break
+
+ dist.barrier()
+
+ if utils.is_main_process():
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write("best epoch: %d"%best_epoch)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/nlvr.yaml')
+ parser.add_argument('--output_dir', default='output/NLVR')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/repositories/BLIP/train_retrieval.py b/repositories/BLIP/train_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..574f03382cc8197b97971a11ae54b632bcfe6655
--- /dev/null
+++ b/repositories/BLIP/train_retrieval.py
@@ -0,0 +1,345 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_retrieval import blip_retrieval
+import utils
+from utils import cosine_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+
+
+def train(model, data_loader, optimizer, epoch, device, config):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+ metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image = image.to(device,non_blocking=True)
+ idx = idx.to(device,non_blocking=True)
+
+ if epoch>0:
+ alpha = config['alpha']
+ else:
+ alpha = config['alpha']*min(1,i/len(data_loader))
+
+ loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
+ loss = loss_ita + loss_itm
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss_itm=loss_itm.item())
+ metric_logger.update(loss_ita=loss_ita.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, device, config):
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Evaluation:'
+
+ print('Computing features for evaluation...')
+ start_time = time.time()
+
+ texts = data_loader.dataset.text
+ num_text = len(texts)
+ text_bs = 256
+ text_ids = []
+ text_embeds = []
+ text_atts = []
+ for i in range(0, num_text, text_bs):
+ text = texts[i: min(num_text, i+text_bs)]
+ text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+ text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+ text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+ text_embeds.append(text_embed)
+ text_ids.append(text_input.input_ids)
+ text_atts.append(text_input.attention_mask)
+
+ text_embeds = torch.cat(text_embeds,dim=0)
+ text_ids = torch.cat(text_ids,dim=0)
+ text_atts = torch.cat(text_atts,dim=0)
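+    # replace the leading [CLS] token id with BLIP's encoder token id so the text
+    # transformer runs as an image-grounded encoder when reranking with the ITM head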
+ text_ids[:,0] = model.tokenizer.enc_token_id
+
+ image_feats = []
+ image_embeds = []
+ for image, img_id in data_loader:
+ image = image.to(device)
+ image_feat = model.visual_encoder(image)
+ image_embed = model.vision_proj(image_feat[:,0,:])
+ image_embed = F.normalize(image_embed,dim=-1)
+
+ image_feats.append(image_feat.cpu())
+ image_embeds.append(image_embed)
+
+ image_feats = torch.cat(image_feats,dim=0)
+ image_embeds = torch.cat(image_embeds,dim=0)
+
+ sims_matrix = image_embeds @ text_embeds.t()
+ score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
+
+ num_tasks = utils.get_world_size()
+ rank = utils.get_rank()
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+ encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
+ output = model.text_encoder(text_ids[topk_idx],
+ attention_mask = text_atts[topk_idx],
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_i2t[start+i,topk_idx] = score + topk_sim
+
+ sims_matrix = sims_matrix.t()
+ score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)
+
+ step = sims_matrix.size(0)//num_tasks + 1
+ start = rank*step
+ end = min(sims_matrix.size(0),start+step)
+
+ for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+ topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+ encoder_output = image_feats[topk_idx].to(device)
+ encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
+ output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+ attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+ encoder_hidden_states = encoder_output,
+ encoder_attention_mask = encoder_att,
+ return_dict = True,
+ )
+ score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+ score_matrix_t2i[start+i,topk_idx] = score + topk_sim
+
+ if args.distributed:
+ dist.barrier()
+ torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
+ torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Evaluation time {}'.format(total_time_str))
+
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
+
+
+
+@torch.no_grad()
+def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
+
+ #Images->Text
+ ranks = np.zeros(scores_i2t.shape[0])
+ for index,score in enumerate(scores_i2t):
+ inds = np.argsort(score)[::-1]
+ # Score
+ rank = 1e20
+ for i in img2txt[index]:
+ tmp = np.where(inds == i)[0][0]
+ if tmp < rank:
+ rank = tmp
+ ranks[index] = rank
+
+ # Compute metrics
+ tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ #Text->Images
+ ranks = np.zeros(scores_t2i.shape[0])
+
+ for index,score in enumerate(scores_t2i):
+ inds = np.argsort(score)[::-1]
+ ranks[index] = np.where(inds == txt2img[index])[0][0]
+
+ # Compute metrics
+ ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+ ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+ ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+ tr_mean = (tr1 + tr5 + tr10) / 3
+ ir_mean = (ir1 + ir5 + ir10) / 3
+ r_mean = (tr_mean + ir_mean) / 2
+
+ eval_result = {'txt_r1': tr1,
+ 'txt_r5': tr5,
+ 'txt_r10': tr10,
+ 'txt_r_mean': tr_mean,
+ 'img_r1': ir1,
+ 'img_r5': ir5,
+ 'img_r10': ir10,
+ 'img_r_mean': ir_mean,
+ 'r_mean': r_mean}
+ return eval_result
+
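+# Worked example for itm_eval (illustrative): if the best-ranked ground-truth
+# caption for two query images lands at sorted positions 0 and 7, ranks = [0, 7],
+# giving txt_r1 = 50.0, txt_r5 = 50.0 and txt_r10 = 100.0.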
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating retrieval dataset")
+ train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
+ else:
+ samplers = [None, None, None]
+
+ train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
+ batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
+ num_workers=[4,4,4],
+ is_trains=[True, False, False],
+ collate_fns=[None,None,None])
+
+
+ #### Model ####
+ print("Creating model")
+ model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+ vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
+ queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ best = 0
+ best_epoch = 0
+
+ print("Start training")
+ start_time = time.time()
+
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device, config)
+
+        score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, device, config)
+ score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
+
+ if utils.is_main_process():
+
+ val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
+ print(val_result)
+
+ if val_result['r_mean']>best:
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+ best = val_result['r_mean']
+ best_epoch = epoch
+
+ test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
+ print(test_result)
+
+ if args.evaluate:
+ log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
+ **{f'test_{k}': v for k, v in test_result.items()},
+ }
+ with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'val_{k}': v for k, v in val_result.items()},
+ **{f'test_{k}': v for k, v in test_result.items()},
+ 'epoch': epoch,
+ 'best_epoch': best_epoch,
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if args.evaluate:
+ break
+
+ dist.barrier()
+ torch.cuda.empty_cache()
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
+ parser.add_argument('--output_dir', default='output/Retrieval_flickr')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/repositories/BLIP/train_vqa.py b/repositories/BLIP/train_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..89eb7490862e517cc660f842396033c21d441a20
--- /dev/null
+++ b/repositories/BLIP/train_vqa.py
@@ -0,0 +1,202 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from models.blip_vqa import blip_vqa
+import utils
+from utils import cosine_lr_schedule
+from data import create_dataset, create_sampler, create_loader
+from data.vqa_dataset import vqa_collate_fn
+from data.utils import save_result
+
+
+def train(model, data_loader, optimizer, epoch, device):
+ # train
+ model.train()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+
+ header = 'Train Epoch: [{}]'.format(epoch)
+ print_freq = 50
+
+ for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
+
+ loss = model(image, question, answer, train=True, n=n, weights=weights)
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ metric_logger.update(loss=loss.item())
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger.global_avg())
+ return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, device, config) :
+ # test
+ model.eval()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Generate VQA test result:'
+ print_freq = 50
+
+ result = []
+
+ if config['inference']=='rank':
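+        # rank mode: tokenize the fixed answer list once and let the model score these
+        # candidates; the first token of every candidate is replaced with BOS just below.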
+ answer_list = data_loader.dataset.answer_list
+ answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
+ answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id
+
+ for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ image = image.to(device,non_blocking=True)
+
+ if config['inference']=='generate':
+ answers = model(image, question, train=False, inference='generate')
+
+ for answer, ques_id in zip(answers, question_id):
+ ques_id = int(ques_id.item())
+ result.append({"question_id":ques_id, "answer":answer})
+
+ elif config['inference']=='rank':
+ answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])
+
+ for ques_id, answer_id in zip(question_id, answer_ids):
+ result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]})
+
+ return result
+
+
+def main(args, config):
+ utils.init_distributed_mode(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ cudnn.benchmark = True
+
+ #### Dataset ####
+ print("Creating vqa datasets")
+ datasets = create_dataset('vqa', config)
+
+ if args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
+ else:
+ samplers = [None, None]
+
+ train_loader, test_loader = create_loader(datasets,samplers,
+ batch_size=[config['batch_size_train'],config['batch_size_test']],
+ num_workers=[4,4],is_trains=[True, False],
+ collate_fns=[vqa_collate_fn,None])
+ #### Model ####
+ print("Creating model")
+ model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
+ vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
+
+ model = model.to(device)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+ best = 0
+ best_epoch = 0
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(0, config['max_epoch']):
+ if not args.evaluate:
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+
+ cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+ train_stats = train(model, train_loader, optimizer, epoch, device)
+
+ else:
+ break
+
+ if utils.is_main_process():
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ }
+ with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ save_obj = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'config': config,
+ 'epoch': epoch,
+ }
+ torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
+
+ dist.barrier()
+
+ vqa_result = evaluation(model_without_ddp, test_loader, device, config)
+ result_file = save_result(vqa_result, args.result_dir, 'vqa_result')
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', default='./configs/vqa.yaml')
+ parser.add_argument('--output_dir', default='output/VQA')
+ parser.add_argument('--evaluate', action='store_true')
+ parser.add_argument('--device', default='cuda')
+ parser.add_argument('--seed', default=42, type=int)
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--distributed', default=True, type=bool)
+ args = parser.parse_args()
+
+ config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+ args.result_dir = os.path.join(args.output_dir, 'result')
+
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+ yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+ main(args, config)
\ No newline at end of file
diff --git a/repositories/BLIP/transform/randaugment.py b/repositories/BLIP/transform/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..094d9f4cacc93146d2bab7311d9dc04feb07032c
--- /dev/null
+++ b/repositories/BLIP/transform/randaugment.py
@@ -0,0 +1,340 @@
+import cv2
+import numpy as np
+
+
+## aug functions
+def identity_func(img):
+ return img
+
+
+def autocontrast_func(img, cutoff=0):
+ '''
+ same output as PIL.ImageOps.autocontrast
+ '''
+ n_bins = 256
+
+ def tune_channel(ch):
+ n = ch.size
+ cut = cutoff * n // 100
+ if cut == 0:
+ high, low = ch.max(), ch.min()
+ else:
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ low = np.argwhere(np.cumsum(hist) > cut)
+ low = 0 if low.shape[0] == 0 else low[0]
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+ if high <= low:
+ table = np.arange(n_bins)
+ else:
+ scale = (n_bins - 1) / (high - low)
+ offset = -low * scale
+ table = np.arange(n_bins) * scale + offset
+ table[table < 0] = 0
+ table[table > n_bins - 1] = n_bins - 1
+ table = table.clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def equalize_func(img):
+ '''
+ same output as PIL.ImageOps.equalize
+ PIL's implementation is different from cv2.equalize
+ '''
+ n_bins = 256
+
+ def tune_channel(ch):
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+ non_zero_hist = hist[hist != 0].reshape(-1)
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+ if step == 0: return ch
+ n = np.empty_like(hist)
+ n[0] = step // 2
+ n[1:] = hist[:-1]
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+ return table[ch]
+
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
+ out = cv2.merge(channels)
+ return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+ '''
+ like PIL, rotate by degree, not radians
+ '''
+ H, W = img.shape[0], img.shape[1]
+ center = W / 2, H / 2
+ M = cv2.getRotationMatrix2D(center, degree, 1)
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+ return out
+
+
+def solarize_func(img, thresh=128):
+ '''
+    same output as PIL.ImageOps.solarize
+ '''
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
+ table = table.clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def color_func(img, factor):
+ '''
+ same output as PIL.ImageEnhance.Color
+ '''
+ ## implementation according to PIL definition, quite slow
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+ # out = blend(degenerate, img, factor)
+ # M = (
+ # np.eye(3) * factor
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+ # )[np.newaxis, np.newaxis, :]
+ M = (
+ np.float32([
+ [0.886, -0.114, -0.114],
+ [-0.587, 0.413, -0.587],
+ [-0.299, -0.299, 0.701]]) * factor
+ + np.float32([[0.114], [0.587], [0.299]])
+ )
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+ return out
+
+
+def contrast_func(img, factor):
+ """
+ same output as PIL.ImageEnhance.Contrast
+ """
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+ table = np.array([(
+ el - mean) * factor + mean
+ for el in range(256)
+ ]).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def brightness_func(img, factor):
+ '''
+    same output as PIL.ImageEnhance.Brightness
+ '''
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+ out = table[img]
+ return out
+
+
+def sharpness_func(img, factor):
+ '''
+    The differences between this result and PIL are all on the 4 boundaries; the center
+    areas are the same.
+ '''
+ kernel = np.ones((3, 3), dtype=np.float32)
+ kernel[1][1] = 5
+ kernel /= 13
+ degenerate = cv2.filter2D(img, -1, kernel)
+ if factor == 0.0:
+ out = degenerate
+ elif factor == 1.0:
+ out = img
+ else:
+ out = img.astype(np.float32)
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+ out = out.astype(np.uint8)
+ return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+ '''
+ same output as PIL.Image.transform
+ '''
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+ '''
+ same output as PIL.Image.transform
+ '''
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def posterize_func(img, bits):
+ '''
+ same output as PIL.ImageOps.posterize
+ '''
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+ return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+ H, W = img.shape[0], img.shape[1]
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
+ return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+ replace = np.array(replace, dtype=np.uint8)
+ H, W = img.shape[0], img.shape[1]
+ rh, rw = np.random.random(2)
+ pad_size = pad_size // 2
+ ch, cw = int(rh * H), int(rw * W)
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+ out = img.copy()
+ out[x1:x2, y1:y2, :] = replace
+ return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+ return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 0.3
+ if np.random.random() > 0.5: level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * float(translate_const)
+ if np.random.random() > 0.5: level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * cutout_const)
+ return (level, replace_value)
+
+ return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 256)
+ return (level, )
+ return level_to_args
+
+
+def none_level_to_args(level):
+ return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+ def level_to_args(level):
+ level = int((level / MAX_LEVEL) * 4)
+ return (level, )
+ return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+ def level_to_args(level):
+ level = (level / MAX_LEVEL) * 30
+ if np.random.random() < 0.5:
+ level = -level
+ return (level, replace_value)
+
+ return level_to_args
+
+
+func_dict = {
+ 'Identity': identity_func,
+ 'AutoContrast': autocontrast_func,
+ 'Equalize': equalize_func,
+ 'Rotate': rotate_func,
+ 'Solarize': solarize_func,
+ 'Color': color_func,
+ 'Contrast': contrast_func,
+ 'Brightness': brightness_func,
+ 'Sharpness': sharpness_func,
+ 'ShearX': shear_x_func,
+ 'TranslateX': translate_x_func,
+ 'TranslateY': translate_y_func,
+ 'Posterize': posterize_func,
+ 'ShearY': shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+ 'Identity': none_level_to_args,
+ 'AutoContrast': none_level_to_args,
+ 'Equalize': none_level_to_args,
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
+ 'Color': enhance_level_to_args(MAX_LEVEL),
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
+ 'TranslateX': translate_level_to_args(
+ translate_const, MAX_LEVEL, replace_value
+ ),
+ 'TranslateY': translate_level_to_args(
+ translate_const, MAX_LEVEL, replace_value
+ ),
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
+}
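+# func_dict (above) maps each op name to the image transform itself, while arg_dict maps
+# the same name to a function turning an integer level (0..MAX_LEVEL) into that
+# transform's positional arguments; RandomAugment below pairs the two at call time.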
+
+
+class RandomAugment(object):
+
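+    # A small RandAugment-style policy: N ops are sampled per image from `augs`
+    # (default: every op registered in arg_dict) and each sampled op is applied
+    # with probability 0.5 at magnitude M.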
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+ self.N = N
+ self.M = M
+ self.isPIL = isPIL
+ if augs:
+ self.augs = augs
+ else:
+ self.augs = list(arg_dict.keys())
+
+ def get_random_ops(self):
+ sampled_ops = np.random.choice(self.augs, self.N)
+ return [(op, 0.5, self.M) for op in sampled_ops]
+
+ def __call__(self, img):
+ if self.isPIL:
+ img = np.array(img)
+ ops = self.get_random_ops()
+ for name, prob, level in ops:
+ if np.random.random() > prob:
+ continue
+ args = arg_dict[name](level)
+ img = func_dict[name](img, *args)
+ return img
+
+
+if __name__ == '__main__':
+ a = RandomAugment()
+ img = np.random.randn(32, 32, 3)
+ a(img)
\ No newline at end of file
diff --git a/repositories/BLIP/utils.py b/repositories/BLIP/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe0e1dc2f5d200156d5dd1acc305a8b7b7b98da
--- /dev/null
+++ b/repositories/BLIP/utils.py
@@ -0,0 +1,278 @@
+import datetime
+import io
+import math
+import os
+import time
+from collections import defaultdict, deque
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+    """Decay the learning rate following a cosine schedule from init_lr to min_lr"""
+    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+    """Linearly warm up the learning rate from init_lr towards max_lr"""
+    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+    """Decay the learning rate by decay_rate every epoch, clamped at min_lr"""
+    lr = max(min_lr, init_lr * (decay_rate**epoch))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def global_avg(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {:.4f}".format(name, meter.global_avg)
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def compute_acc(logits, label, reduction='mean'):
+ ret = (torch.argmax(logits, dim=1) == label).float()
+ if reduction == 'none':
+ return ret.detach()
+ elif reduction == 'mean':
+ return ret.mean().item()
+
+def compute_n_params(model, return_str=True):
+ tot = 0
+ for p in model.parameters():
+ w = 1
+ for x in p.shape:
+ w *= x
+ tot += w
+ if return_str:
+ if tot >= 1e6:
+ return '{:.1f}M'.format(tot / 1e6)
+ else:
+ return '{:.1f}K'.format(tot / 1e3)
+ else:
+ return tot
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
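+    # Reads the standard torch.distributed launcher variables (RANK, WORLD_SIZE,
+    # LOCAL_RANK), e.g. as set by torchrun / torch.distributed.launch; falls back to
+    # SLURM_PROCID under SLURM, and otherwise disables distributed mode.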
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+    print('| distributed init (rank {}, world {}): {}'.format(
+ args.rank, args.world_size, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
\ No newline at end of file
diff --git a/repositories/CodeFormer/.gitignore b/repositories/CodeFormer/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..15f690d1628f48e9d7d755a9472a44975fa06839
--- /dev/null
+++ b/repositories/CodeFormer/.gitignore
@@ -0,0 +1,128 @@
+.vscode
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+# *.png
+# *.jpeg
+# *.jpg
+*.pt
+*.gif
+*.pth
+*.dat
+*.zip
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# project
+results/
+dlib/
+*_old*
+
diff --git a/repositories/CodeFormer/README.md b/repositories/CodeFormer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bb5e1fe6d881c0862835828700ed2cfc8e999463
--- /dev/null
+++ b/repositories/CodeFormer/README.md
@@ -0,0 +1,123 @@
+
+
+
+
+## Towards Robust Blind Face Restoration with Codebook Lookup Transformer
+
+[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
+
+
+ [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) ![visitors](https://visitor-badge.glitch.me/badge?page_id=sczhou/CodeFormer)
+
+[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+S-Lab, Nanyang Technological University
+
+
+
+
+:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+### Update
+
+- **2022.09.09**: Integrated into [Replicate](https://replicate.com/). Try out the online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
+- **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
+- **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
+- **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
+- **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`.
+- **2022.07.17**: Add Colab demo of CodeFormer.
+- **2022.07.16**: Release inference code for face restoration. :blush:
+- **2022.06.21**: This repo is created.
+
+### TODO
+- [ ] Add checkpoint for face inpainting
+- [ ] Add training code and config files
+- [x] ~~Add background image enhancement~~
+
+#### Face Restoration
+
+
+
+
+#### Face Color Enhancement and Restoration
+
+
+
+#### Face Inpainting
+
+
+
+
+
+### Dependencies and Installation
+
+- Pytorch >= 1.7.1
+- CUDA >= 10.1
+- Other required packages in `requirements.txt`
+```
+# git clone this repository
+git clone https://github.com/sczhou/CodeFormer
+cd CodeFormer
+
+# create new anaconda env
+conda create -n codeformer python=3.8 -y
+conda activate codeformer
+
+# install python dependencies
+pip3 install -r requirements.txt
+python basicsr/setup.py develop
+```
+
+
+### Quick Inference
+
+##### Download Pre-trained Models:
+Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can download the pretrained models manually, or fetch them by running the following command.
+```
+python scripts/download_pretrained_models.py facelib
+```
+
+Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can download the pretrained models manually, or fetch them by running the following command.
+```
+python scripts/download_pretrained_models.py CodeFormer
+```
+
+##### Prepare Testing Data:
+You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder.
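+
+For reference, a minimal input layout could look like this (the folder names come from above; the contents are only illustrative):
+```
+inputs/
+├── TestWhole/        # whole photos; faces are detected, restored and blended back
+└── cropped_faces/    # already cropped and aligned face images
+```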
+
+
+##### Testing on Face Restoration:
+```
+# For cropped and aligned faces
+python inference_codeformer.py --w 0.5 --has_aligned --test_path [input folder]
+
+# For the whole images
+# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
+# Add '--face_upsample' to further upsample the restored faces with Real-ESRGAN
+python inference_codeformer.py --w 0.7 --test_path [input folder]
+```
+
+NOTE that *w* is in [0, 1]. Generally, smaller *w* tends to produce a higher-quality result, while larger *w* yields a higher-fidelity result.
+
+The results will be saved in the `results` folder.
+
+### Citation
+If our work is useful for your research, please consider citing:
+
+ @article{zhou2022codeformer,
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+ journal = {arXiv preprint arXiv:2206.11253},
+ year = {2022}
+ }
+
+### License
+
+
+This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
+
+### Acknowledgement
+
+This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). We also borrow some codes from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). Thanks for their awesome works.
+
+### Contact
+If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`.
\ No newline at end of file
diff --git a/repositories/CodeFormer/basicsr/VERSION b/repositories/CodeFormer/basicsr/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..1892b926767774e9ba91f1e584fa71b4c56abb69
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/VERSION
@@ -0,0 +1 @@
+1.3.2
diff --git a/repositories/CodeFormer/basicsr/__init__.py b/repositories/CodeFormer/basicsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ffcccd7fc0f33b59d99d73d0436d60e561b0fc
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/__init__.py
@@ -0,0 +1,11 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .train import *
+from .utils import *
+from .version import __gitsha__, __version__
diff --git a/repositories/CodeFormer/basicsr/archs/__init__.py b/repositories/CodeFormer/basicsr/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb1e4d7bb221c429082bd389d9140e5b1cc07b0
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/__init__.py
@@ -0,0 +1,25 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop('type')
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
+ return net
diff --git a/repositories/CodeFormer/basicsr/archs/arcface_arch.py b/repositories/CodeFormer/basicsr/archs/arcface_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5afb7bd2b359e0c2b7efdf628ab10b63964d87
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/arcface_arch.py
@@ -0,0 +1,245 @@
+import torch.nn as nn
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+ """A simple wrapper for 3x3 convolution with padding.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ outplanes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ """
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ """Basic residual block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class IRBlock(nn.Module):
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+ super(IRBlock, self).__init__()
+ self.bn0 = nn.BatchNorm2d(inplanes)
+ self.conv1 = conv3x3(inplanes, inplanes)
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.prelu = nn.PReLU()
+ self.conv2 = conv3x3(inplanes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.use_se = use_se
+ if self.use_se:
+ self.se = SEBlock(planes)
+
+ def forward(self, x):
+ residual = x
+ out = self.bn0(x)
+ out = self.conv1(out)
+ out = self.bn1(out)
+ out = self.prelu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.use_se:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.prelu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 4 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class SEBlock(nn.Module):
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+ Args:
+ channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ratio. Default: 16.
+ """
+
+ def __init__(self, channel, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
+ nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+ """ArcFace with ResNet architectures.
+
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+ Args:
+ block (str): Block used in the ArcFace architecture.
+ layers (tuple(int)): Block numbers in each layer.
+        use_se (bool): Whether to use the SEBlock (squeeze-and-excitation block). Default: True.
+ """
+
+ def __init__(self, block, layers, use_se=True):
+ if block == 'IRBlock':
+ block = IRBlock
+ self.inplanes = 64
+ self.use_se = use_se
+ super(ResNetArcFace, self).__init__()
+
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.prelu = nn.PReLU()
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.bn4 = nn.BatchNorm2d(512)
+ self.dropout = nn.Dropout()
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
+ self.bn5 = nn.BatchNorm1d(512)
+
+ # initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+ self.inplanes = planes
+ for _ in range(1, num_blocks):
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn4(x)
+ x = self.dropout(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc5(x)
+ x = self.bn5(x)
+
+ return x
\ No newline at end of file
diff --git a/repositories/CodeFormer/basicsr/archs/arch_util.py b/repositories/CodeFormer/basicsr/archs/arch_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad45ab34e901c47fb539152fca714a3795b0de2
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/arch_util.py
@@ -0,0 +1,318 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-+-
+ |________________|
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == 'ratio':
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == 'shape':
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+ from the preceding features, this DCNv2Pack takes another different
+ features to generate offsets and masks.
+
+ Ref:
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+ else:
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
\ No newline at end of file
diff --git a/repositories/CodeFormer/basicsr/archs/codeformer_arch.py b/repositories/CodeFormer/basicsr/archs/codeformer_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d0d8027c8c4ffb26af6f4ba361514e93e320e8d
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/codeformer_arch.py
@@ -0,0 +1,276 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from basicsr.archs.vqgan_arch import *
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+    Adjust the reference features to have similar color and illumination
+    to those of the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+class Fuse_sft_block(nn.Module):
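+    # SFT-style fusion: a per-pixel (scale, shift) pair is predicted from the concatenated
+    # encoder/decoder features, and w * (dec_feat * scale + shift) is added back onto the
+    # decoder feature, so w controls how strongly encoder detail is injected.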
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ codebook_size=1024, latent_size=256,
+ connect_list=['32', '64', '128', '256'],
+ fix_modules=['quantize','generator']):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd*2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
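+        # w: fidelity weight -- when w > 0 the encoder features are fused back into the
+        #    generator through the Fuse_sft_block modules (larger w keeps more of the input).
+        # detach_16: detach the quantized feature so gradients stop there (training stage III).
+        # code_only: return only the code logits and the encoder feature (training stage II).
+        # adain: match the quantized feature statistics to the encoder feature via AdaIN.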
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits doesn't need softmax before cross_entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits doesn't need softmax before cross_entropy loss
+ return out, logits, lq_feat
\ No newline at end of file
diff --git a/repositories/CodeFormer/basicsr/archs/rrdbnet_arch.py b/repositories/CodeFormer/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a2d6c204557cba53ada7550deb587541855cfb
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+ We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64
+ num_block (int): Block number in the trunk network. Default: 23.
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
\ No newline at end of file
diff --git a/repositories/CodeFormer/basicsr/archs/vgg_arch.py b/repositories/CodeFormer/basicsr/archs/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bb0103c8b14ef2588028f7177753db9af62cae
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
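+
+# Usage sketch (illustrative only, not executed): insert_bn interleaves a
+# matching 'bnX_Y' entry after every 'convX_Y' entry, e.g.
+#
+#     insert_bn(['conv1_1', 'relu1_1', 'pool1'])
+#     # -> ['conv1_1', 'bn1_1', 'relu1_1', 'pool1']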
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+ In this implementation, we allow users to choose whether to use normalization
+ on the input feature and which type of vgg network to use. Note that the
+ pretrained path must fit the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image. Importantly,
+ the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
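+
+# Usage sketch (illustrative only, not executed; layer names and input size are
+# assumptions): request a few activations by name and read them from the
+# returned dict; the input is expected to be an RGB batch in [0, 1] when
+# use_input_norm is True.
+#
+#     extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu3_1'], vgg_type='vgg19')
+#     feats = extractor(torch.rand(1, 3, 224, 224))
+#     feats['relu3_1']   # feature map taken right after the relu3_1 layer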
diff --git a/repositories/CodeFormer/basicsr/archs/vqgan_arch.py b/repositories/CodeFormer/basicsr/archs/vqgan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6dfcf4c9983b431f0a978701e5ddd9598faf381
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/archs/vqgan_arch.py
@@ -0,0 +1,435 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "min_encoding_scores": min_encoding_scores,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ # use self.out_channels so that out_channels=None falls back to in_channels
+ self.norm1 = normalize(in_channels)
+ self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(self.out_channels)
+ self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+ # initial convolution
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta #0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+ if model_path is not None:
+ # reuse the already-loaded checkpoint instead of reading the file again
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_ema' in chkpt:
+ self.load_state_dict(chkpt['params_ema'])
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+ elif 'params' in chkpt:
+ self.load_state_dict(chkpt['params'])
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
+ else:
+ raise ValueError('Wrong params!')
+
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
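+
+# Usage sketch (illustrative only, not executed): the arguments below mirror the
+# ones the CodeFormer architecture passes to this base class; the forward pass
+# returns the reconstruction, the codebook loss and the quantizer stats.
+#
+#     vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8],
+#                           quantizer='nearest', res_blocks=2, attn_resolutions=[16],
+#                           codebook_size=1024)
+#     recon, codebook_loss, stats = vqgan(torch.randn(1, 3, 512, 512))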
+
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ # reuse the already-loaded checkpoint instead of reading the file again
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_d' in chkpt:
+ self.load_state_dict(chkpt['params_d'])
+ elif 'params' in chkpt:
+ self.load_state_dict(chkpt['params'])
+ else:
+ raise ValueError('Wrong params!')
+
+ def forward(self, x):
+ return self.main(x)
\ No newline at end of file
diff --git a/repositories/CodeFormer/basicsr/data/__init__.py b/repositories/CodeFormer/basicsr/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6adb4bb6a926af7a46aaec4794eee95fda02a33
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/data/__init__.py
@@ -0,0 +1,100 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+ dataset_opt (dict): Configuration for dataset. It must contain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt['phase']
+ rank, _ = get_dist_info()
+ if phase == 'train':
+ if dist: # distributed training
+ batch_size = dataset_opt['batch_size_per_gpu']
+ num_workers = dataset_opt['num_worker_per_gpu']
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True)
+ if sampler is None:
+ dataloader_args['shuffle'] = True
+ dataloader_args['worker_init_fn'] = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ elif phase in ['val', 'test']: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+ prefetch_mode = dataset_opt.get('prefetch_mode')
+ if prefetch_mode == 'cpu': # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+ logger = get_root_logger()
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
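+
+# Usage sketch (illustrative only, not executed; the dataset type and values are
+# placeholders): both helpers are driven by an option dict using the keys
+# documented above.
+#
+#     dataset_opt = {'name': 'demo', 'type': 'SomeRegisteredDataset', 'phase': 'train',
+#                    'batch_size_per_gpu': 4, 'num_worker_per_gpu': 2}
+#     train_set = build_dataset(dataset_opt)
+#     train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, seed=0)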
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
diff --git a/repositories/CodeFormer/basicsr/data/data_sampler.py b/repositories/CodeFormer/basicsr/data/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ Modified from torch.utils.data.distributed.DistributedSampler
+ Support enlarging the dataset for iteration-based training, to save
+ time when restarting the dataloader after each epoch.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
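+
+# Usage sketch (illustrative only, not executed; world_size, rank and epoch are
+# placeholders): the sampler is passed to build_dataloader's `sampler` argument,
+# and set_epoch() should be called before each epoch so the shuffle order changes.
+#
+#     sampler = EnlargedSampler(train_set, num_replicas=world_size, rank=rank, ratio=100)
+#     loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=True, sampler=sampler)
+#     sampler.set_epoch(epoch)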
diff --git a/repositories/CodeFormer/basicsr/data/data_util.py b/repositories/CodeFormer/basicsr/data/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b1bce8e089485182c962e830a163d6d0059da8
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/data/data_util.py
@@ -0,0 +1,305 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+ max_frame_num (int): Max number of the sequence of images (from 1).
+ num_frames (int): Reading num_frames frames.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+ Examples: current_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ pad_idx = 0
+ elif padding == 'reflection':
+ pad_idx = -i
+ elif padding == 'reflection_circle':
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == 'replicate':
+ pad_idx = max_frame_num
+ elif padding == 'reflection':
+ pad_idx = max_frame_num * 2 - i
+ elif padding == 'reflection_circle':
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+ Contents of lmdb. Taking `lq.lmdb` as an example, the file structure is:
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records
+ 1)image name (with extension),
+ 2)image shape,
+ 3)compression level, separated by a white space.
+ Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
+ Note that this key is different from lmdb keys.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+ f'formats. But received {input_key}: {input_folder}; '
+ f'{gt_key}: {gt_folder}')
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+ Example of a meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file, 'r') as fin:
+ gt_names = [line.split(' ')[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+ f'{len(input_paths)}, {len(gt_paths)}.')
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith('.lmdb'):
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
+ paths = [line.split('.')[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+ from scipy.ndimage import filters as filters
+ kernel = np.zeros((kernel_size, kernel_size))
+ # set element at the middle to one, a dirac delta
+ kernel[kernel_size // 2, kernel_size // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter
+ return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
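+
+# Usage sketch (illustrative only, not executed; shapes are assumptions):
+# duf_downsample expects (b, t, c, h, w) or (t, c, h, w) frames, blurs them with
+# a Gaussian kernel and then downsamples by the given scale.
+#
+#     frames = torch.rand(2, 5, 3, 64, 64)                      # (b, t, c, h, w)
+#     small = duf_downsample(frames, kernel_size=13, scale=4)   # -> (2, 5, 3, 16, 16)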
diff --git a/repositories/CodeFormer/basicsr/data/prefetch_dataloader.py b/repositories/CodeFormer/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+ It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
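+
+# Usage sketch (illustrative only, not executed): both prefetchers expose the
+# same next()/reset() interface, so a training loop can stay agnostic of the
+# prefetch mode.
+#
+#     prefetcher = CUDAPrefetcher(train_loader, opt={'num_gpu': 1})
+#     batch = prefetcher.next()
+#     while batch is not None:
+#         ...                    # training step on `batch`
+#         batch = prefetcher.next()
+#     prefetcher.reset()         # before starting the next epoch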
diff --git a/repositories/CodeFormer/basicsr/data/transforms.py b/repositories/CodeFormer/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..aead9dc73ed063e1c5865040eaa2652b26aa3ad3
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/data/transforms.py
@@ -0,0 +1,165 @@
+import cv2
+import random
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[:h - h_remainder, :w - w_remainder, ...]
+ else:
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
+ return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+ """Paired random crop.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ h_lq, w_lq, _ = img_lqs[0].shape
+ h_gt, w_gt, _ = img_gts[0].shape
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+ # build a single message string (a trailing comma here would pass two args to ValueError)
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+ rotation (bool): Rotation. Default: True.
+ flows (list[ndarray]): Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
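+
+# Usage sketch (illustrative only, not executed): the same random flip/rotation
+# is applied to every image (and flow) in the list, which keeps paired LQ/GT
+# patches aligned during training.
+#
+#     img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)
+#     img, (hflip, vflip, rot90) = augment(img, return_status=True)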
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
diff --git a/repositories/CodeFormer/basicsr/losses/__init__.py b/repositories/CodeFormer/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b184e74c861e6fca0c548692a9a949a6100b0aa
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/losses/__init__.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import LOSS_REGISTRY
+from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
+ gradient_penalty_loss, r1_penalty)
+
+__all__ = [
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
+ 'r1_penalty', 'g_path_regularize'
+]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop('type')
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+ return loss
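+
+# Usage sketch (illustrative only, not executed): losses are created from an
+# option dict whose 'type' names a class registered in LOSS_REGISTRY; the
+# remaining keys are forwarded to that class's constructor.
+#
+#     l1 = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
+#     percep = build_loss({'type': 'PerceptualLoss', 'layer_weights': {'conv5_4': 1.0}})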
diff --git a/repositories/CodeFormer/basicsr/losses/loss_util.py b/repositories/CodeFormer/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..744eeb46d1f3b5a7b4553ca23237ddd9c899a698
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/losses/loss_util.py
@@ -0,0 +1,95 @@
+import functools
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == 'sum':
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == 'mean':
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction)
+ return loss
+
+ return wrapper
diff --git a/repositories/CodeFormer/basicsr/losses/losses.py b/repositories/CodeFormer/basicsr/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bcf272cfb756d99451a3005567ea4d4c9059067
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/losses/losses.py
@@ -0,0 +1,455 @@
+import math
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+ return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(L1Loss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+ """MSE (L2) loss.
+
+ Args:
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(MSELoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+ variant of L1Loss).
+
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+ Super-Resolution".
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ eps (float): A value used to control the curvature near zero.
+ Default: 1e-12.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+ super(CharbonnierLoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+ """Weighted TV loss.
+
+ Args:
+ loss_weight (float): Loss weight. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+    def forward(self, pred, weight=None):
+        if weight is None:
+            y_weight = None
+            x_weight = None
+        else:
+            y_weight = weight[:, :, :-1, :]
+            x_weight = weight[:, :, :, :-1]
+
+        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+ loss = x_diff + y_diff
+
+ return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and the loss will be multiplied by the
+ weight. Default: 1.0.
+ style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and the loss will be multiplied by the weight.
+ Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+            self.criterion = torch.nn.MSELoss()  # torch.nn has no L2loss; MSE is the L2 criterion
+ elif self.criterion_type == 'mse':
+ self.criterion = torch.nn.MSELoss(reduction='mean')
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+ gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
+
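A hedged usage sketch for `PerceptualLoss`: it assumes the pretrained VGG19 weights used by `VGGFeatureExtractor` can be downloaded or are cached locally, and the tensor names are placeholders.

```python
# Illustrative sketch only; requires the VGG19 weights used by VGGFeatureExtractor.
# from basicsr.losses.losses import PerceptualLoss
import torch

percep = PerceptualLoss(layer_weights={'conv5_4': 1.0}, vgg_type='vgg19',
                        perceptual_weight=1.0, style_weight=0.0, criterion='l1')

restored = torch.rand(1, 3, 128, 128)  # network output in [0, 1]
gt = torch.rand(1, 3, 128, 128)        # ground truth in [0, 1]

percep_loss, style_loss = percep(restored, gt)  # style_loss is None when style_weight == 0
```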
+
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+ def __init__(self,
+ loss_weight=1.0,
+ use_input_norm=True,
+ range_norm=False,):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target):
+ if self.range_norm:
+ pred = (pred + 1) / 2
+ target = (target + 1) / 2
+ if self.use_input_norm:
+ pred = (pred - self.mean) / self.std
+ target = (target - self.mean) / self.std
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean()
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only for generators; it is always 1.0
+ for discriminators.
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'wgan_softplus':
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+        """wgan loss with softplus. Softplus is a smooth approximation to the
+ ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is for discriminators or not.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ target_label = self.get_target_label(input, target_is_real)
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
+
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Ref:
+ Eq. 9 in Which training methods for GANs do actually converge.
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty
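To show how the GAN pieces above fit together, here is a hedged sketch of a WGAN-GP style discriminator objective; `disc`, `real`, `fake`, and `gp_weight` are placeholders, not part of the patch.

```python
# Illustrative sketch only: a WGAN-GP discriminator loss built from GANLoss
# and gradient_penalty_loss above. `disc`, `real`, `fake` are placeholders.
import torch

gan_loss = GANLoss(gan_type='wgan', loss_weight=1.0)

def discriminator_loss(disc, real, fake, gp_weight=100.0):
    d_real = disc(real)
    d_fake = disc(fake.detach())
    # wgan: -mean(D(real)) + mean(D(fake)); is_disc=True keeps loss_weight at 1.0
    loss_d = gan_loss(d_real, True, is_disc=True) + gan_loss(d_fake, False, is_disc=True)
    # gradient penalty on interpolates between real and fake samples
    loss_d = loss_d + gp_weight * gradient_penalty_loss(disc, real, fake.detach())
    return loss_d
```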
diff --git a/repositories/CodeFormer/basicsr/metrics/__init__.py b/repositories/CodeFormer/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d55cc8321f124c918d78465b053aef67f13a33
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/metrics/__init__.py
@@ -0,0 +1,19 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim']
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Metric type.
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop('type')
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
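A small hedged example of how `calculate_metric` dispatches through `METRIC_REGISTRY`: the keyword arguments collected from `data` and `opt` must match the signature of the registered metric (here `calculate_psnr` from `psnr_ssim.py`).

```python
# Sketch: dispatching to the registered calculate_psnr metric.
# from basicsr.metrics import calculate_metric
import numpy as np

img1 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
img2 = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)

psnr = calculate_metric(
    data=dict(img1=img1, img2=img2),
    opt=dict(type='calculate_psnr', crop_border=4, test_y_channel=False),
)
```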
diff --git a/repositories/CodeFormer/basicsr/metrics/metric_util.py b/repositories/CodeFormer/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d18f0f7816431bed6af9d58319c6435bdf5c971
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+        (ndarray): Images with range [0, 255] (float type) without rounding.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.
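A short illustration of the two helpers above, under the assumption that inputs are NumPy arrays in the ranges documented in the docstrings.

```python
# Quick illustration of reorder_image and to_y_channel.
import numpy as np

chw = np.random.rand(3, 32, 32)              # CHW float image
hwc = reorder_image(chw, input_order='CHW')  # -> shape (32, 32, 3)

gray = np.random.rand(32, 32)                # 2D input gains a trailing channel axis
assert reorder_image(gray).shape == (32, 32, 1)

bgr = np.random.randint(0, 256, (32, 32, 3)).astype(np.float32)
y = to_y_channel(bgr)                        # (32, 32, 1), Y channel, still in [0, 255]
```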
diff --git a/repositories/CodeFormer/basicsr/metrics/psnr_ssim.py b/repositories/CodeFormer/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd950699c2495880236883861d9e199f900eae8
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: psnr result.
+ """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by func:`calculate_ssim`.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+ Returns:
+ float: ssim result.
+ """
+
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate SSIM (structural similarity).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the SSIM calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: ssim result.
+ """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ ssims = []
+ for i in range(img1.shape[2]):
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
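A hedged end-to-end example of the two registered metrics on synthetic HWC images in [0, 255]; the noise level is arbitrary.

```python
# Sketch: PSNR/SSIM between a ground truth image and a lightly perturbed copy.
import numpy as np

gt = np.random.randint(0, 256, (128, 128, 3)).astype(np.float64)
restored = np.clip(gt + np.random.normal(0, 5, gt.shape), 0, 255)

psnr = calculate_psnr(gt, restored, crop_border=4)
ssim = calculate_ssim(gt, restored, crop_border=4)
print(f'PSNR: {psnr:.2f} dB  SSIM: {ssim:.4f}')
```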
diff --git a/repositories/CodeFormer/basicsr/models/__init__.py b/repositories/CodeFormer/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00bde45f003698a5b15d3517ae47b59ef1d86e0c
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+ model_type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+ logger = get_root_logger()
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
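A hedged sketch of the expected call pattern; 'SRModel' is a hypothetical registered name, since the concrete `*_model.py` files that populate `MODEL_REGISTRY` are added elsewhere in this patch.

```python
# Sketch only: 'SRModel' is a hypothetical name; it must match a class that the
# auto-import above registered via @MODEL_REGISTRY.register().
opt = {
    'model_type': 'SRModel',
    'is_train': False,
    # ... plus whatever options the chosen model class reads from `opt` ...
}
model = build_model(opt)
```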
diff --git a/repositories/CodeFormer/basicsr/ops/__init__.py b/repositories/CodeFormer/basicsr/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/CodeFormer/basicsr/ops/dcn/__init__.py b/repositories/CodeFormer/basicsr/ops/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+ modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+ 'modulated_deform_conv'
+]
diff --git a/repositories/CodeFormer/basicsr/ops/dcn/deform_conv.py b/repositories/CodeFormer/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..734154f9ed9447d585eae7df6886acb136f8a3cf
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,377 @@
+import math
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+try:
+ from . import deform_conv_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ deform_conv_ext = load(
+ 'deform_conv',
+ sources=[
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+ ],
+ )
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_ext.deform_conv_forward(input, weight,
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} is not divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} is not divisible ' \
+ f'by groups {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
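A hedged usage sketch for the deformable convolution wrappers above. It assumes the compiled `deform_conv_ext` CUDA extension is available and a GPU is present, since the autograd Functions raise `NotImplementedError` on CPU.

```python
# Sketch: the *Pack variants predict their own offsets (and masks) internally,
# so they behave like drop-in replacements for nn.Conv2d. Requires CUDA plus
# the compiled deform_conv_ext extension.
import torch

x = torch.rand(1, 64, 32, 32, device='cuda')

dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, padding=1, deformable_groups=1).cuda()
y = dcn(x)  # -> (1, 64, 32, 32)

# The lower-level DeformConv expects an explicit offset map with
# 2 * deformable_groups * kH * kW channels.
base = DeformConv(64, 64, kernel_size=3, padding=1).cuda()
offset = torch.zeros(1, 2 * 1 * 3 * 3, 32, 32, device='cuda')
y2 = base(x, offset)  # zero offsets reduce to an ordinary convolution
```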
diff --git a/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d9424908ed2dbd4ac3cdb98d13e09287a4d2f2d
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DivRtn.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+             kernel_h, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+             kernel_h, kernel_w, kernel_h_, kernel_w_);
+  if (channels != channels_kernel * group)
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+             channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide into group
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
diff --git a/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..98752dccf8c58817ca1a952554dd3f33188a2d34
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,867 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
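+// Grid-stride loop: each thread handles index, index + blockDim.x * gridDim.x, ... so a
+// grid bounded by kMaxGridNum can still cover an arbitrarily large number of elements.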
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
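+// Bilinearly samples the input plane at a fractional location (h, w); corners that fall
+// outside the plane contribute zero.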
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const scalar_t map_h = i * dilation_h + offset_h;
+ //const scalar_t map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+void deformable_im2col(
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h, const int ksize_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ // todo: check parallel_imgs is correctly passed in
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+ deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+ height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+ 2 * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+void deformable_col2im(
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im)
+{
+
+ // todo: make sure parallel_imgs is passed in correctly
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+ deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+ const scalar_t *data_im, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, scalar_t *grad_offset)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ const scalar_t weight = get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+
+ deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+ height_col, width_col, grad_offset_);
+ }));
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
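+// DCNv2 (modulated) im2col: same gather as the plain deformable kernel above, except each
+// bilinearly sampled value is additionally scaled by its learned modulation mask.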
+template <typename scalar_t>
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ //data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_im,
+ const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_offset, scalar_t *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const scalar_t weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+ modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor grad_im)
+{
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+ modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ at::Tensor grad_offset, at::Tensor grad_mask)
+{
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+ scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+
+ modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset_, grad_mask_);
+ }));
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp b/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..41c6df6f721bd95a525fd6a03dd9882e863de042
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/dcn/src/deform_conv_ext.cpp
@@ -0,0 +1,164 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/DeviceGuard.h>
+
+#include <cmath>
+#include <vector>
+
+#define WITH_CUDA // always use cuda
+#ifdef WITH_CUDA
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias);
+#endif
+
+int deform_conv_forward(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_forward_cuda(input, weight, offset, output, columns,
+ ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
+ deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_input_cuda(input, offset, gradOutput,
+ gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
+ dilationW, dilationH, group, deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_parameters(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
+ gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
+ dilationH, group, deformable_group, scale, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
+ offset, mask, output, columns, kernel_h, kernel_w, stride_h,
+ stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
+ deformable_group, with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
+ offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
+ grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
+ pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+ with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward", &deform_conv_forward,
+ "deform forward");
+ m.def("deform_conv_backward_input", &deform_conv_backward_input,
+ "deform_conv_backward_input");
+ m.def("deform_conv_backward_parameters",
+ &deform_conv_backward_parameters,
+ "deform_conv_backward_parameters");
+ m.def("modulated_deform_conv_forward",
+ &modulated_deform_conv_forward,
+ "modulated deform conv forward");
+ m.def("modulated_deform_conv_backward",
+ &modulated_deform_conv_backward,
+ "modulated deform conv backward");
+}
diff --git a/repositories/CodeFormer/basicsr/ops/fused_act/__init__.py b/repositories/CodeFormer/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/repositories/CodeFormer/basicsr/ops/fused_act/fused_act.py b/repositories/CodeFormer/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..588f815e596ab0fc83ab0f9d21426c22ec5ed7c3
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,89 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+try:
+ from . import fused_act_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ fused_act_ext = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+ ],
+ )
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+ ctx.scale)
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/repositories/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp b/repositories/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..85ed0a79fb9c75f83470ac834090f03608d998ee
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
diff --git a/repositories/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/repositories/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..54c7ff53ce8306db2b3c582ec7fa6696a38b4df0
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
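+ // act * 10 + grad selects the branch: 1x = linear, 3x = leaky ReLU; grad 0 = forward,
+ // grad 1 = gradient w.r.t. the input (sign taken from the reference activation),
+ // grad 2 = second-order gradient, which is zero for these piecewise-linear activations.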
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
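+ // step_b = number of elements per channel slice (product of the trailing spatial dims),
+ // so the kernel recovers the bias index as (xi / step_b) % size_b.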
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ y.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ b.data_ptr<scalar_t>(),
+ ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
diff --git a/repositories/CodeFormer/basicsr/ops/upfirdn2d/__init__.py b/repositories/CodeFormer/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ['upfirdn2d']
diff --git a/repositories/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/repositories/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43d0b6783a5b512b55815a291fcac2bebeea31e0
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
@@ -0,0 +1,24 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
diff --git a/repositories/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/repositories/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8870063bae4468deab2e721f0978fe9facfb01b1
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
@@ -0,0 +1,370 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int out_y = minor_idx / p.minor_dim;
+ minor_idx -= out_y * p.minor_dim;
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major && major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, out_x = out_x_base;
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+ const scalar_t *x_p =
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+ minor_idx];
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+ int x_px = p.minor_dim;
+ int k_px = -p.up_x;
+ int x_py = p.in_w * p.minor_dim;
+ int k_py = -p.up_y * p.kernel_w;
+
+ scalar_t v = 0.0f;
+
+ for (int y = 0; y < h; y++) {
+ for (int x = 0; x < w; x++) {
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+ x_p += x_px;
+ k_p += k_px;
+ }
+
+ x_p += x_py - w * x_px;
+ k_p += k_py - w * k_px;
+ }
+
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+ tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major & major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
+ loop_x < p.loop_x & tile_out_x < p.out_w;
+ loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+ in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+ p.minor_dim +
+ minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+ out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+#pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] *
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+ const torch::Tensor &kernel, int up_x, int up_y,
+ int down_x, int down_y, int pad_x0, int pad_x1,
+ int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+ p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+ p.down_x;
+
+ auto out =
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ } else {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 4;
+ block_size = dim3(4, 32, 1);
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ default:
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+ }
+ });
+
+ return out;
+}
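
The tile specializations above only affect performance; the output geometry is determined by the resampling and padding parameters alone. A small Python sketch of the same size computation used in `upfirdn2d_op` (names are illustrative):

    def upfirdn2d_out_shape(in_h, in_w, kernel_h, kernel_w,
                            up_x, up_y, down_x, down_y,
                            pad_x0, pad_x1, pad_y0, pad_y1):
        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        return out_h, out_w

    # A 3x3 kernel with up=1, down=1 and padding of 1 on every side preserves the size.
    assert upfirdn2d_out_shape(64, 64, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1) == (64, 64)
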
diff --git a/repositories/CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py b/repositories/CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..667f96e1ded35d48f163f37e21d1ed8ff191aac3
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,186 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+try:
+ from . import upfirdn2d_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ upfirdn2d_ext = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+ ],
+ )
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ if input.device.type == 'cpu':
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ else:
+ out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+ return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
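
A minimal usage sketch for the `upfirdn2d` wrapper above. A CPU tensor exercises the pure-PyTorch `upfirdn2d_native` path, so the compiled `upfirdn2d_ext` extension is not required (shapes and kernel are illustrative):

    import torch

    x = torch.randn(1, 3, 64, 64)                  # NCHW input
    k = torch.ones(3, 3) / 9.0                     # normalized 3x3 box filter
    y = upfirdn2d(x, k, up=1, down=1, pad=(1, 1))  # pad of 1 per side keeps 64x64
    assert y.shape == (1, 3, 64, 64)
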
diff --git a/repositories/CodeFormer/basicsr/setup.py b/repositories/CodeFormer/basicsr/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..382a2aa1006e581eaf31dbb3155d4b0ba3b31140
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/setup.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+
+from setuptools import find_packages, setup
+
+import os
+import subprocess
+import sys
+import time
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
+
+version_file = './basicsr/version.py'
+
+
+def readme():
+ with open('README.md', encoding='utf-8') as f:
+ content = f.read()
+ return content
+
+
+def get_git_hash():
+
+ def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ except OSError:
+ sha = 'unknown'
+
+ return sha
+
+
+def get_hash():
+ if os.path.exists('.git'):
+ sha = get_git_hash()[:7]
+ elif os.path.exists(version_file):
+ try:
+ from version import __version__
+ sha = __version__.split('+')[-1]
+ except ImportError:
+ raise ImportError('Unable to get git version')
+ else:
+ sha = 'unknown'
+
+ return sha
+
+
+def write_version_py():
+ content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+__gitsha__ = '{}'
+version_info = ({})
+"""
+ sha = get_hash()
+ with open('./basicsr/VERSION', 'r') as f:
+ SHORT_VERSION = f.read().strip()
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
+
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+ with open(version_file, 'w') as f:
+ f.write(version_file_str)
+
+
+def get_version():
+ with open(version_file, 'r') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+def make_cuda_ext(name, module, sources, sources_cuda=None):
+ if sources_cuda is None:
+ sources_cuda = []
+ define_macros = []
+ extra_compile_args = {'cxx': []}
+
+ if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+ define_macros += [('WITH_CUDA', None)]
+ extension = CUDAExtension
+ extra_compile_args['nvcc'] = [
+ '-D__CUDA_NO_HALF_OPERATORS__',
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
+ '-D__CUDA_NO_HALF2_OPERATORS__',
+ ]
+ sources += sources_cuda
+ else:
+ print(f'Compiling {name} without CUDA')
+ extension = CppExtension
+
+ return extension(
+ name=f'{module}.{name}',
+ sources=[os.path.join(*module.split('.'), p) for p in sources],
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args)
+
+
+def get_requirements(filename='requirements.txt'):
+ with open(os.path.join('.', filename), 'r') as f:
+ requires = [line.replace('\n', '') for line in f.readlines()]
+ return requires
+
+
+if __name__ == '__main__':
+ if '--cuda_ext' in sys.argv:
+ ext_modules = [
+ make_cuda_ext(
+ name='deform_conv_ext',
+ module='ops.dcn',
+ sources=['src/deform_conv_ext.cpp'],
+ sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
+ make_cuda_ext(
+ name='fused_act_ext',
+ module='ops.fused_act',
+ sources=['src/fused_bias_act.cpp'],
+ sources_cuda=['src/fused_bias_act_kernel.cu']),
+ make_cuda_ext(
+ name='upfirdn2d_ext',
+ module='ops.upfirdn2d',
+ sources=['src/upfirdn2d.cpp'],
+ sources_cuda=['src/upfirdn2d_kernel.cu']),
+ ]
+ sys.argv.remove('--cuda_ext')
+ else:
+ ext_modules = []
+
+ write_version_py()
+ setup(
+ name='basicsr',
+ version=get_version(),
+ description='Open Source Image and Video Super-Resolution Toolbox',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ author='Xintao Wang',
+ author_email='xintao.wang@outlook.com',
+ keywords='computer vision, restoration, super resolution',
+ url='https://github.com/xinntao/BasicSR',
+ include_package_data=True,
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: OS Independent',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ ],
+ license='Apache License 2.0',
+ setup_requires=['cython', 'numpy'],
+ install_requires=get_requirements(),
+ ext_modules=ext_modules,
+ cmdclass={'build_ext': BuildExtension},
+ zip_safe=False)
diff --git a/repositories/CodeFormer/basicsr/train.py b/repositories/CodeFormer/basicsr/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a01c0dfccdb8b02283100ec5b792c33afaf22f5e
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/train.py
@@ -0,0 +1,225 @@
+import argparse
+import datetime
+import logging
+import math
+import copy
+import random
+import time
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.data.data_sampler import EnlargedSampler
+from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from basicsr.models import build_model
+from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
+ init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
+from basicsr.utils.dist_util import get_dist_info, init_dist
+from basicsr.utils.options import dict2str, parse
+
+import warnings
+# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
+warnings.filterwarnings("ignore", category=UserWarning)
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ opt = parse(args.opt, root_path, is_train=is_train)
+
+ # distributed settings
+ if args.launcher == 'none':
+ opt['dist'] = False
+ print('Disable distributed.', flush=True)
+ else:
+ opt['dist'] = True
+ if args.launcher == 'slurm' and 'dist_params' in opt:
+ init_dist(args.launcher, **opt['dist_params'])
+ else:
+ init_dist(args.launcher)
+
+ opt['rank'], opt['world_size'] = get_dist_info()
+
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
+
+ return opt
+
+
+def init_loggers(opt):
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+
+ # initialize wandb logger before tensorboard logger to allow proper sync:
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
+ assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
+ init_wandb_logger(opt)
+ tb_logger = None
+ if opt['logger'].get('use_tb_logger'):
+ tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
+ return logger, tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+ # create train and val dataloaders
+ train_loader, val_loader = None, None
+ for phase, dataset_opt in opt['datasets'].items():
+ if phase == 'train':
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
+ train_set = build_dataset(dataset_opt)
+ train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
+ train_loader = build_dataloader(
+ train_set,
+ dataset_opt,
+ num_gpu=opt['num_gpu'],
+ dist=opt['dist'],
+ sampler=train_sampler,
+ seed=opt['manual_seed'])
+
+ num_iter_per_epoch = math.ceil(
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+ total_iters = int(opt['train']['total_iter'])
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+ logger.info('Training statistics:'
+ f'\n\tNumber of train images: {len(train_set)}'
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+
+ elif phase == 'val':
+ val_set = build_dataset(dataset_opt)
+ val_loader = build_dataloader(
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
+ else:
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
+
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
+
+
+def train_pipeline(root_path):
+ # parse options, set distributed setting, set random seed
+ opt = parse_options(root_path, is_train=True)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # load resume states if necessary
+ if opt['path'].get('resume_state'):
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(
+ opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ resume_state = None
+
+ # mkdir for experiments and logger
+ if resume_state is None:
+ make_exp_dirs(opt)
+ if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
+
+ # initialize loggers
+ logger, tb_logger = init_loggers(opt)
+
+ # create train and validation dataloaders
+ result = create_train_val_dataloader(opt, logger)
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
+
+ # create model
+ if resume_state: # resume training
+ check_resume(opt, resume_state['iter'])
+ model = build_model(opt)
+ model.resume_training(resume_state) # handle optimizers and schedulers
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+ else:
+ model = build_model(opt)
+ start_epoch = 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ # dataloader prefetcher
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
+ if prefetch_mode is None or prefetch_mode == 'cpu':
+ prefetcher = CPUPrefetcher(train_loader)
+ elif prefetch_mode == 'cuda':
+ prefetcher = CUDAPrefetcher(train_loader, opt)
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
+ if opt['datasets']['train'].get('pin_memory') is not True:
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
+ else:
+ raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.")
+
+ # training
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
+ data_time, iter_time = time.time(), time.time()
+ start_time = time.time()
+
+ for epoch in range(start_epoch, total_epochs + 1):
+ train_sampler.set_epoch(epoch)
+ prefetcher.reset()
+ train_data = prefetcher.next()
+
+ while train_data is not None:
+ data_time = time.time() - data_time
+
+ current_iter += 1
+ if current_iter > total_iters:
+ break
+ # update learning rate
+ model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
+ # training
+ model.feed_data(train_data)
+ model.optimize_parameters(current_iter)
+ iter_time = time.time() - iter_time
+ # log
+ if current_iter % opt['logger']['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': model.get_current_learning_rate()})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ # save models and training states
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
+ logger.info('Saving models and training states.')
+ model.save(epoch, current_iter)
+
+ # validation
+ if opt.get('val') is not None and opt['datasets'].get('val') is not None \
+ and (current_iter % opt['val']['val_freq'] == 0):
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+
+ data_time = time.time()
+ iter_time = time.time()
+ train_data = prefetcher.next()
+ # end of iter
+
+ # end of epoch
+
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
+ logger.info(f'End of training. Time consumed: {consumed_time}')
+ logger.info('Save the latest model.')
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
+ if opt.get('val') is not None and opt['datasets'].get('val'):
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+ if tb_logger:
+ tb_logger.close()
+
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
diff --git a/repositories/CodeFormer/basicsr/utils/__init__.py b/repositories/CodeFormer/basicsr/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcc1d540462712387523d1e326d1dfc2bcfbf32
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/__init__.py
@@ -0,0 +1,29 @@
+from .file_client import FileClient
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+
+__all__ = [
+ # file_client.py
+ 'FileClient',
+ # img_util.py
+ 'img2tensor',
+ 'tensor2img',
+ 'imfrombytes',
+ 'imwrite',
+ 'crop_border',
+ # logger.py
+ 'MessageLogger',
+ 'init_tb_logger',
+ 'init_wandb_logger',
+ 'get_root_logger',
+ 'get_env_info',
+ # misc.py
+ 'set_random_seed',
+ 'get_time_str',
+ 'mkdir_and_rename',
+ 'make_exp_dirs',
+ 'scandir',
+ 'check_resume',
+ 'sizeof_fmt'
+]
diff --git a/repositories/CodeFormer/basicsr/utils/dist_util.py b/repositories/CodeFormer/basicsr/utils/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If the argument ``port`` is not specified, the master port is taken from the
+ system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set,
+ the default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
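
A short usage sketch for the helpers above; the decorated function and path are hypothetical:

    @master_only
    def save_checkpoint(path):
        # Only the rank-0 process runs the body; other ranks return None immediately.
        print(f'saving checkpoint to {path}')

    rank, world_size = get_dist_info()  # (0, 1) when torch.distributed is not initialized
    save_checkpoint('experiments/latest.pth')
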
diff --git a/repositories/CodeFormer/basicsr/utils/download_util.py b/repositories/CodeFormer/basicsr/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a267915743ee3f3232bc8fe992466b52468979a
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/download_util.py
@@ -0,0 +1,95 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ print(response_file_size)
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
\ No newline at end of file
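
A usage sketch for `load_file_from_url` above; the URL and directory are placeholders:

    weights_path = load_file_from_url(
        url='https://example.com/path/to/model.pth',  # placeholder URL
        model_dir='weights',                          # file is cached as ./weights/model.pth
        progress=True,
    )
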
diff --git a/repositories/CodeFormer/basicsr/utils/file_client.py b/repositories/CodeFormer/basicsr/utils/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f38d9796da3899048924f2f803d1088927966b0
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+ # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+ The client loads a file or text from its path using the specified backend
+ and returns it as a binary object. It can also register other backend
+ accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
diff --git a/repositories/CodeFormer/basicsr/utils/img_util.py b/repositories/CodeFormer/basicsr/utils/img_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d409a132ff216e6943a276fb5d8cd5f410824883
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/img_util.py
@@ -0,0 +1,170 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+ # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+ """This implementation is slightly faster than tensor2img.
+ It now only supports torch tensor with shape (1, c, h, w).
+
+ Args:
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+ min_max (tuple[int]): min and max values for clamp.
+ """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+ float32 (bool): Whether to change to float32. If True, will also normalize
+ to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+ crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
diff --git a/repositories/CodeFormer/basicsr/utils/lmdb_util.py b/repositories/CodeFormer/basicsr/utils/lmdb_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0a10f60ffca2e36ac5f5564aafd70e79d06a723
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/lmdb_util.py
@@ -0,0 +1,196 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records 1)image name (with extension),
+ 2)image shape, and 3)compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (str): Image path list.
+ keys (str): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+ f'but got {len(img_path_list)} and {len(keys)}')
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+ print(f'Total images: {len(img_path_list)}')
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+ pbar = tqdm(total=len(img_path_list), unit='image')
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f'Finish reading {len(img_path_list)} images.')
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print('Data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
+ key_byte = key.encode('ascii')
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print('\nFinish writing lmdb.')
+
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode('ascii')
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
diff --git a/repositories/CodeFormer/basicsr/utils/logger.py b/repositories/CodeFormer/basicsr/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9714bf59c30fc82de24c1ee58d9118d0864b3572
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/logger.py
@@ -0,0 +1,169 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class MessageLogger():
+ """Message logger for printing.
+ Args:
+ opt (dict): Config. It contains the following keys:
+ name (str): Exp name.
+ logger (dict): Contains 'print_freq' (str) for logger interval.
+ train (dict): Contains 'total_iter' (int) for total iters.
+ use_tb_logger (bool): Use tensorboard logger.
+ start_iter (int): Start iter. Default: 1.
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt['name']
+ self.interval = opt['logger']['print_freq']
+ self.start_iter = start_iter
+ self.max_iters = opt['train']['total_iter']
+ self.use_tb_logger = opt['logger']['use_tb_logger']
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ @master_only
+ def __call__(self, log_vars):
+ """Format logging message.
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop('epoch')
+ current_iter = log_vars.pop('iter')
+ lrs = log_vars.pop('lrs')
+
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
+ for v in lrs:
+ message += f'{v:.3e},'
+ message += ')] '
+
+ # time and estimated time
+ if 'time' in log_vars.keys():
+ iter_time = log_vars.pop('time')
+ data_time = log_vars.pop('data_time')
+
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f'[eta: {eta_str}, '
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f'{k}: {v:.4e} '
+ # tensorboard logger
+ if self.use_tb_logger:
+ if k.startswith('l_'):
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+ else:
+ self.tb_logger.add_scalar(k, v, current_iter)
+ self.logger.info(message)
+
+
+@master_only
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+ """We now only use wandb to sync tensorboard log."""
+ import wandb
+ logger = logging.getLogger('basicsr')
+
+ project = opt['logger']['wandb']['project']
+ resume_id = opt['logger']['wandb'].get('resume_id')
+ if resume_id:
+ wandb_id = resume_id
+ resume = 'allow'
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
+ else:
+ wandb_id = wandb.util.generate_id()
+ resume = 'never'
+
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel('ERROR')
+ elif log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ # file_handler = logging.FileHandler(log_file, 'w')
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+def get_env_info():
+ """Get environment information.
+ Currently, it only logs software versions.
+ """
+ import torch
+ import torchvision
+
+ from basicsr.version import __version__
+ msg = r"""
+ ____ _ _____ ____
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
+ ______ __ __ __ __
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
+ """
+ msg += ('\nVersion Information: '
+ f'\n\tBasicSR: {__version__}'
+ f'\n\tPyTorch: {torch.__version__}'
+ f'\n\tTorchVision: {torchvision.__version__}')
+ return msg
\ No newline at end of file
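A minimal sketch of how the logger utilities above are typically used; the log file name is hypothetical and only illustrates that, on rank 0 with a `log_file` set, messages go to stdout and are appended to the file:

    import logging

    from basicsr.utils.logger import get_root_logger

    # Hypothetical file name; the FileHandler above is opened in append mode.
    logger = get_root_logger(log_level=logging.INFO, log_file='codeformer.log')
    logger.info('logger initialized')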
diff --git a/repositories/CodeFormer/basicsr/utils/matlab_functions.py b/repositories/CodeFormer/basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ce1004a2c9f8521505c4b5889d3c24a909c70d
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/matlab_functions.py
@@ -0,0 +1,347 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+ """cubic function used for calculate_weights_indices."""
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+ (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+ Args:
+ in_length (int): Input length.
+ out_length (int): Output length.
+ scale (float): Scale factor.
+ kernel (str): Kernel type; currently only the 'cubic' kernel is used.
+ kernel_width (int): Kernel width.
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+ out_length, p)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+ """imresize function same as MATLAB.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ Default: True.
+
+ Returns:
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+ """
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+ antialiasing)
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+ antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+ if numpy_type:
+ out_2 = out_2.numpy().transpose(1, 2, 0)
+ return out_2
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in color space
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in color space conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
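A small sketch of the MATLAB-compatible helpers above; the array is random test data and only illustrates the expected shapes and value ranges:

    import numpy as np

    from basicsr.utils.matlab_functions import imresize, rgb2ycbcr

    img = np.random.rand(64, 64, 3).astype(np.float32)  # HWC image in [0, 1]
    half = imresize(img, scale=0.5)                      # MATLAB-compatible bicubic -> (32, 32, 3)
    y = rgb2ycbcr(img, y_only=True)                      # BT.601 luma channel, same dtype/range as input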
diff --git a/repositories/CodeFormer/basicsr/utils/misc.py b/repositories/CodeFormer/basicsr/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b444ff3b950e38f43a5451d1330ff1b65951a9e
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/misc.py
@@ -0,0 +1,134 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+from .logger import get_root_logger
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ mkdir_and_rename(path_opt.pop('experiments_root'))
+ else:
+ mkdir_and_rename(path_opt.pop('results_root'))
+ for key, path in path_opt.items():
+ if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ logger = get_root_logger()
+ if opt['path']['resume_state']:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith('network_')]
+ flag_pretrain = False
+ for network in networks:
+ if opt['path'].get(f'pretrain_{network}') is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ logger.warning('pretrain_network path will be ignored during resuming.')
+ # set pretrained model paths
+ for network in networks:
+ name = f'pretrain_{network}'
+ basename = network.replace('network_', '')
+ if opt['path'].get('ignore_resume_networks') is None or (basename
+ not in opt['path']['ignore_resume_networks']):
+ opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+ logger.info(f"Set {name} to {opt['path'][name]}")
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+ str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
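A brief sketch of the misc helpers above; the dataset directory is hypothetical:

    from basicsr.utils.misc import scandir, set_random_seed, sizeof_fmt

    set_random_seed(0)                                   # seeds random, numpy and torch
    for path in scandir('datasets/val', suffix='.png'):  # hypothetical directory
        print(path)
    print(sizeof_fmt(3 * 1024 ** 2))                     # -> '3.0 MB'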
diff --git a/repositories/CodeFormer/basicsr/utils/options.py b/repositories/CodeFormer/basicsr/utils/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..db490e4aa52e26fde31959fd74c2cef3af2ecf76
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/options.py
@@ -0,0 +1,108 @@
+import yaml
+import time
+from collections import OrderedDict
+from os import path as osp
+from basicsr.utils.misc import get_time_str
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+def parse(opt_path, root_path, is_train=True):
+ """Parse option file.
+
+ Args:
+ opt_path (str): Option file path.
+ root_path (str): Root path under which experiment/result directories are created.
+ is_train (bool): Indicate whether in training or not. Default: True.
+
+ Returns:
+ (dict): Options.
+ """
+ with open(opt_path, mode='r') as f:
+ Loader, _ = ordered_yaml()
+ opt = yaml.load(f, Loader=Loader)
+
+ opt['is_train'] = is_train
+
+ # opt['name'] = f"{get_time_str()}_{opt['name']}"
+ if opt['path'].get('resume_state', None): # Shangchen added
+ resume_state_path = opt['path'].get('resume_state')
+ opt['name'] = resume_state_path.split("/")[-3]
+ else:
+ opt['name'] = f"{get_time_str()}_{opt['name']}"
+
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ # for several datasets, e.g., test_1, test_2
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if 'scale' in opt:
+ dataset['scale'] = opt['scale']
+ if dataset.get('dataroot_gt') is not None:
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+ if dataset.get('dataroot_lq') is not None:
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+ # paths
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+ else: # test
+ results_root = osp.join(root_path, 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = '\n'
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':['
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
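A minimal sketch of loading a config with the ordered YAML loader above; the option file path is hypothetical:

    import yaml

    from basicsr.utils.options import dict2str, ordered_yaml

    Loader, _ = ordered_yaml()                       # keeps the key order of the config file
    with open('options/codeformer_train.yml', 'r') as f:  # hypothetical path
        opt = yaml.load(f, Loader=Loader)
    print(dict2str(opt))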
diff --git a/repositories/CodeFormer/basicsr/utils/realesrgan_utils.py b/repositories/CodeFormer/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff94523b7ddd61f0b72280950fd36e1b8133bf4c
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,296 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from basicsr.utils.download_util import load_file_from_url
+from torch.nn import functional as F
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+ model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
+ model (nn.Module): The defined network. Default: None.
+ tile (int): Since large images can cause out-of-GPU-memory errors, this tile option first crops the
+ input image into tiles, processes each tile, and then merges the results back into one image.
+ 0 means tiling is disabled. Default: 0.
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+ half (bool): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(self,
+ scale,
+ model_path,
+ model=None,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ if gpu_id:
+ self.device = torch.device(
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+ else:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+ # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
+ if model_path.startswith('https://'):
+ model_path = load_file_from_url(
+ url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+ # prefer to use params_ema
+ if 'params_ema' in loadnet:
+ keyname = 'params_ema'
+ else:
+ keyname = 'params'
+ model.load_state_dict(loadnet[keyname], strict=True)
+ model.eval()
+ self.model = model.to(self.device)
+ if self.half:
+ self.model = self.model.half()
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+ """
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if (h % self.mod_scale != 0):
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+ if (w % self.mod_scale != 0):
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+ Finally, all the processed tiles are merged into one images.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print('Error', error)
+ # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y,
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print('\tInput is a 16-bit image')
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = 'L'
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = 'RGBA'
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == 'realesrgan':
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = 'RGB'
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ with torch.no_grad():
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img_t = self.post_process()
+ output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == 'L':
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ del output_img_t
+ torch.cuda.empty_cache()
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == 'RGBA':
+ if alpha_upsampler == 'realesrgan':
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output, (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ), interpolation=cv2.INTER_LANCZOS4)
+
+ return output, img_mode
+
+
+class PrefetchReader(threading.Thread):
+ """Prefetch images.
+
+ Args:
+ img_list (list[str]): A list of image paths to be read.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, img_list, num_prefetch_queue):
+ super().__init__()
+ self.que = queue.Queue(num_prefetch_queue)
+ self.img_list = img_list
+
+ def run(self):
+ for img_path in self.img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ self.que.put(img)
+
+ self.que.put(None)
+
+ def __next__(self):
+ next_item = self.que.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class IOConsumer(threading.Thread):
+
+ def __init__(self, opt, que, qid):
+ super().__init__()
+ self._queue = que
+ self.qid = qid
+ self.opt = opt
+
+ def run(self):
+ while True:
+ msg = self._queue.get()
+ if isinstance(msg, str) and msg == 'quit':
+ break
+
+ output = msg['output']
+ save_path = msg['save_path']
+ cv2.imwrite(save_path, output)
+ print(f'IO worker {self.qid} is done.')
\ No newline at end of file
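A hedged usage sketch for the RealESRGANer helper above. It assumes an RRDBNet architecture module is available in this basicsr copy and that local x2 weights exist at the path shown; all file names here are illustrative:

    import cv2
    from basicsr.archs.rrdbnet_arch import RRDBNet  # assumed to be present in this basicsr copy
    from basicsr.utils.realesrgan_utils import RealESRGANer

    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    upsampler = RealESRGANer(
        scale=2,
        model_path='weights/realesrgan/RealESRGAN_x2plus.pth',  # hypothetical local weights; a URL would be downloaded first
        model=model,
        tile=400,       # tile to limit GPU memory use; 0 disables tiling
        tile_pad=10,
        pre_pad=0,
        half=False)
    img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)  # hypothetical input image
    output, _ = upsampler.enhance(img, outscale=2)
    cv2.imwrite('restored.png', output)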
diff --git a/repositories/CodeFormer/basicsr/utils/registry.py b/repositories/CodeFormer/basicsr/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..655753b3b9cbd0cfe73fe93a77cf1fcc3db6d827
--- /dev/null
+++ b/repositories/CodeFormer/basicsr/utils/registry.py
@@ -0,0 +1,82 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
+
+
+class Registry():
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj):
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+ f"in '{self._name}' registry!")
+ self._obj_map[name] = obj
+
+ def register(self, obj=None):
+ """
+ Register the given object under the name `obj.__name__`.
+ Can be used either as a decorator or as a plain function call.
+ See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj)
+
+ def get(self, name):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
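A short sketch of the registry above, mirroring the class docstring; TinyArch is a hypothetical class used only to illustrate registration and lookup:

    from basicsr.utils.registry import ARCH_REGISTRY

    @ARCH_REGISTRY.register()
    class TinyArch:  # hypothetical architecture, registered under its class name
        pass

    assert 'TinyArch' in ARCH_REGISTRY
    arch_cls = ARCH_REGISTRY.get('TinyArch')
    print(list(ARCH_REGISTRY.keys()))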
diff --git a/repositories/CodeFormer/cog.yaml b/repositories/CodeFormer/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..236299ed059180ebceee87963742ec0d555fceec
--- /dev/null
+++ b/repositories/CodeFormer/cog.yaml
@@ -0,0 +1,25 @@
+build:
+ gpu: true
+ cuda: "11.3"
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "ipython==8.4.0"
+ - "future==0.18.2"
+ - "lmdb==1.3.0"
+ - "scikit-image==0.19.3"
+ - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+ - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+ - "scipy==1.9.0"
+ - "gdown==4.5.1"
+ - "pyyaml==6.0"
+ - "tb-nightly==2.11.0a20220906"
+ - "tqdm==4.64.1"
+ - "yapf==0.32.0"
+ - "lpips==0.1.4"
+ - "Pillow==9.2.0"
+ - "opencv-python==4.6.0.66"
+
+predict: "predict.py:Predictor"
diff --git a/repositories/CodeFormer/facelib/detection/__init__.py b/repositories/CodeFormer/facelib/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..296262d4e2e29eaa2afba7bda1f0399d77da24f6
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/__init__.py
@@ -0,0 +1,100 @@
+import os
+import torch
+from torch import nn
+from copy import deepcopy
+
+from facelib.utils import load_file_from_url
+from facelib.utils import download_pretrained_models
+from facelib.detection.yolov5face.models.common import Conv
+
+from .retinaface.retinaface import RetinaFace
+from .yolov5face.face_detector import YoloDetector
+
+
+def init_detection_model(model_name, half=False, device='cuda'):
+ if 'retinaface' in model_name:
+ model = init_retinaface_model(model_name, half, device)
+ elif 'YOLOv5' in model_name:
+ model = init_yolov5face_model(model_name, device)
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ return model
+
+
+def init_retinaface_model(model_name, half=False, device='cuda'):
+ if model_name == 'retinaface_resnet50':
+ model = RetinaFace(network_name='resnet50', half=half)
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
+ elif model_name == 'retinaface_mobile0.25':
+ model = RetinaFace(network_name='mobile0.25', half=half)
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+
+ return model
+
+
+def init_yolov5face_model(model_name, device='cuda'):
+ if model_name == 'YOLOv5l':
+ model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth'
+ elif model_name == 'YOLOv5n':
+ model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ model.detector.load_state_dict(load_net, strict=True)
+ model.detector.eval()
+ model.detector = model.detector.to(device).float()
+
+ for m in model.detector.modules():
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+ m.inplace = True # pytorch 1.7.0 compatibility
+ elif isinstance(m, Conv):
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+ return model
+
+
+# Download from Google Drive
+# def init_yolov5face_model(model_name, device='cuda'):
+# if model_name == 'YOLOv5l':
+# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
+# elif model_name == 'YOLOv5n':
+# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
+# else:
+# raise NotImplementedError(f'{model_name} is not implemented.')
+
+# model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
+# if not os.path.exists(model_path):
+# download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
+
+# load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+# model.detector.load_state_dict(load_net, strict=True)
+# model.detector.eval()
+# model.detector = model.detector.to(device).float()
+
+# for m in model.detector.modules():
+# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+# m.inplace = True # pytorch 1.7.0 compatibility
+# elif isinstance(m, Conv):
+# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+# return model
\ No newline at end of file
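A minimal sketch of the detector factory above; it assumes the RetinaFace weights can be fetched into weights/facelib on first use:

    import torch

    from facelib.detection import init_detection_model

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    detector = init_detection_model('retinaface_resnet50', half=False, device=device)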
diff --git a/repositories/CodeFormer/facelib/detection/align_trans.py b/repositories/CodeFormer/facelib/detection/align_trans.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f1eb365462c2ec5bbac6d1854c786b6fd6be90
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/align_trans.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
+ [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+ def __str__(self):
+ return 'In File {}:{}'.format(__file__, super().__str__())
+
+
+def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
+ """
+ Function:
+ ----------
+ get reference 5 key points according to crop settings:
+ 0. Set default crop_size:
+ if default_square:
+ crop_size = (112, 112)
+ else:
+ crop_size = (96, 112)
+ 1. Pad the crop_size by inner_padding_factor in each side;
+ 2. Resize crop_size into (output_size - outer_padding*2),
+ pad into output_size with outer_padding;
+ 3. Output reference_5point;
+ Parameters:
+ ----------
+ @output_size: (w, h) or None
+ size of aligned face image
+ @inner_padding_factor: (w_factor, h_factor)
+ padding factor for inner (w, h)
+ @outer_padding: (w_pad, h_pad)
+ each row is a pair of coordinates (x, y)
+ @default_square: True or False
+ if True:
+ default crop_size = (112, 112)
+ else:
+ default crop_size = (96, 112);
+ !!! make sure, if output_size is not None:
+ (output_size - outer_padding)
+ = some_scale * (default crop_size * (1.0 +
+ inner_padding_factor))
+ Returns:
+ ----------
+ @reference_5point: 5x2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+ # 0) make the inner region a square
+ if default_square:
+ size_diff = max(tmp_crop_size) - tmp_crop_size
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += size_diff
+
+ if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
+
+ return tmp_5pts
+
+ if (inner_padding_factor == 0 and outer_padding == (0, 0)):
+ if output_size is None:
+ return tmp_5pts
+ else:
+ raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+ # check output size
+ if not (0 <= inner_padding_factor <= 1.0):
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
+ output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
+ output_size += np.array(outer_padding)
+ if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
+
+ # 1) pad the inner region according inner_padding_factor
+ if inner_padding_factor > 0:
+ size_diff = tmp_crop_size * inner_padding_factor * 2
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+ # 2) resize the padded inner region
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+ raise FaceWarpException('Must have (output_size - outer_padding)'
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
+
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+ tmp_5pts = tmp_5pts * scale_factor
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+ # tmp_5pts = tmp_5pts + size_diff / 2
+ tmp_crop_size = size_bf_outer_pad
+
+ # 3) add outer_padding to make output_size
+ reference_5point = tmp_5pts + np.array(outer_padding)
+ tmp_crop_size = output_size
+
+ return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+ """
+ Function:
+ ----------
+ get affine transform matrix 'tfm' from src_pts to dst_pts
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points matrix, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points matrix, each row is a pair of coordinates (x, y)
+ Returns:
+ ----------
+ @tfm: 2x3 np.array
+ transform matrix from src_pts to dst_pts
+ """
+
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+ n_pts = src_pts.shape[0]
+ ones = np.ones((n_pts, 1), src_pts.dtype)
+ src_pts_ = np.hstack([src_pts, ones])
+ dst_pts_ = np.hstack([dst_pts, ones])
+
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
+
+ if rank == 3:
+ tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+ elif rank == 2:
+ tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+ return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+ Parameters:
+ ----------
+ @src_img: HxWxC np.array
+ input image
+ @facial_pts: could be
+ 1)a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ @reference_pts: could be
+ 1) a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ or
+ 3) None
+ if None, use default reference facial points
+ @crop_size: (w, h)
+ output face image size
+ @align_type: transform type, could be one of
+ 1) 'similarity': use similarity transform
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
+ by calling cv2.getAffineTransform()
+ 3) 'affine': use all points to do affine transform
+ Returns:
+ ----------
+ @face_img: output face image with size (w, h) = @crop_size
+ """
+
+ if reference_pts is None:
+ if crop_size[0] == 96 and crop_size[1] == 112:
+ reference_pts = REFERENCE_FACIAL_POINTS
+ else:
+ default_square = False
+ inner_padding_factor = 0
+ outer_padding = (0, 0)
+ output_size = crop_size
+
+ reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
+ default_square)
+
+ ref_pts = np.float32(reference_pts)
+ ref_pts_shp = ref_pts.shape
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+ raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if ref_pts_shp[0] == 2:
+ ref_pts = ref_pts.T
+
+ src_pts = np.float32(facial_pts)
+ src_pts_shp = src_pts.shape
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+ raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if src_pts_shp[0] == 2:
+ src_pts = src_pts.T
+
+ if src_pts.shape != ref_pts.shape:
+ raise FaceWarpException('facial_pts and reference_pts must have the same shape')
+
+ if align_type == 'cv2_affine':
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+ elif align_type == 'affine':
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
+ else:
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
+
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
+
+ return face_img
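A small sketch of the alignment helpers above; the image path and the five landmark coordinates (eyes, nose tip, mouth corners) are hypothetical:

    import cv2
    import numpy as np

    from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face

    img = cv2.imread('face.jpg')  # hypothetical input image (BGR)
    landmarks = np.array([[193.8, 240.0], [282.6, 237.4], [239.1, 294.5],
                          [202.4, 349.3], [275.7, 347.1]], dtype=np.float32)  # hypothetical 5-point landmarks
    ref_pts = get_reference_facial_points(output_size=(112, 112), default_square=True)
    aligned = warp_and_crop_face(img, landmarks, reference_pts=ref_pts, crop_size=(112, 112))
    cv2.imwrite('face_aligned.jpg', aligned)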
diff --git a/repositories/CodeFormer/facelib/detection/matlab_cp2tform.py b/repositories/CodeFormer/facelib/detection/matlab_cp2tform.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a8b54a91709c71437e15c68d3be9a9b0a20a34
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/matlab_cp2tform.py
@@ -0,0 +1,317 @@
+import numpy as np
+from numpy.linalg import inv, lstsq
+from numpy.linalg import matrix_rank as rank
+from numpy.linalg import norm
+
+
+class MatlabCp2tormException(Exception):
+
+ def __str__(self):
+ return 'In File {}:{}'.format(__file__, super().__str__())
+
+
+def tformfwd(trans, uv):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+ uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy = np.dot(uv, trans)
+ xy = xy[:, 0:-1]
+ return xy
+
+
+def tforminv(trans, uv):
+ """
+ Function:
+ ----------
+ apply the inverse of affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of inverse-transformed coordinates (x, y)
+ """
+ Tinv = inv(trans)
+ xy = tformfwd(Tinv, uv)
+ return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+ options = {'K': 2}
+
+ K = options['K']
+ M = xy.shape[0]
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+ X = np.vstack((tmp1, tmp2))
+
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+ U = np.vstack((u, v))
+
+ # We know that X * r = U
+ if rank(X) >= 2 * K:
+ r, _, _, _ = lstsq(X, U, rcond=-1)
+ r = np.squeeze(r)
+ else:
+ raise Exception('cp2tform:twoUniquePointsReq')
+ sc = r[0]
+ ss = r[1]
+ tx = r[2]
+ ty = r[3]
+
+ Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+ T = inv(Tinv)
+ T[:, 2] = np.array([0, 0, 1])
+
+ return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+ options = {'K': 2}
+
+ # uv = np.array(uv)
+ # xy = np.array(xy)
+
+ # Solve for trans1
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+ # Solve for trans2
+
+ # manually reflect the xy data across the Y-axis
+ xyR = xy
+ xyR[:, 0] = -1 * xyR[:, 0]
+
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+ # manually reflect the tform to undo the reflection done on xyR
+ TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+ trans2 = np.dot(trans2r, TreflectY)
+
+ # Figure out if trans1 or trans2 is better
+ xy1 = tformfwd(trans1, uv)
+ norm1 = norm(xy1 - xy)
+
+ xy2 = tformfwd(trans2, uv)
+ norm2 = norm(xy2 - xy)
+
+ if norm1 <= norm2:
+ return trans1, trans1_inv
+ else:
+ trans2_inv = inv(trans2)
+ return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'trans':
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y, 1] = [u, v, 1] * trans
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ @reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+ trans_inv: 3x3 np.array
+ inverse of trans, transform matrix from xy to uv
+ """
+
+ if reflective:
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
+ else:
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+ return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+ """
+ Function:
+ ----------
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ cv2_trans = trans[:, 0:2].T
+
+ return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+ return cv2_trans
+
+
+if __name__ == '__main__':
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ # In Matlab, run:
+ #
+ # uv = [u'; v'];
+ # xy = [x'; y'];
+ # tform_sim=cp2tform(uv,xy,'similarity');
+ #
+ # trans = tform_sim.tdata.T
+ # ans =
+ # -0.0764 -1.6190 0
+ # 1.6190 -0.0764 0
+ # -3.2156 0.0290 1.0000
+ # trans_inv = tform_sim.tdata.Tinv
+ # ans =
+ #
+ # -0.0291 0.6163 0
+ # -0.6163 -0.0291 0
+ # -0.0756 1.9826 1.0000
+ # xy_m=tformfwd(tform_sim, u,v)
+ #
+ # xy_m =
+ #
+ # -3.2156 0.0290
+ # 1.1833 -9.9143
+ # 5.0323 2.8853
+ # uv_m=tforminv(tform_sim, x,y)
+ #
+ # uv_m =
+ #
+ # 0.5698 1.3953
+ # 6.0872 2.2733
+ # -2.6570 4.3314
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ uv = np.array((u, v)).T
+ xy = np.array((x, y)).T
+
+ print('\n--->uv:')
+ print(uv)
+ print('\n--->xy:')
+ print(xy)
+
+ trans, trans_inv = get_similarity_transform(uv, xy)
+
+ print('\n--->trans matrix:')
+ print(trans)
+
+ print('\n--->trans_inv matrix:')
+ print(trans_inv)
+
+ print('\n---> apply transform to uv')
+ print('\nxy_m = uv_augmented * trans')
+ uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy_m = np.dot(uv_aug, trans)
+ print(xy_m)
+
+ print('\nxy_m = tformfwd(trans, uv)')
+ xy_m = tformfwd(trans, uv)
+ print(xy_m)
+
+ print('\n---> apply inverse transform to xy')
+ print('\nuv_m = xy_augmented * trans_inv')
+ xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
+ uv_m = np.dot(xy_aug, trans_inv)
+ print(uv_m)
+
+ print('\nuv_m = tformfwd(trans_inv, xy)')
+ uv_m = tformfwd(trans_inv, xy)
+ print(uv_m)
+
+ uv_m = tforminv(trans, xy)
+ print('\nuv_m = tforminv(trans, xy)')
+ print(uv_m)
diff --git a/repositories/CodeFormer/facelib/detection/retinaface/retinaface.py b/repositories/CodeFormer/facelib/detection/retinaface/retinaface.py
new file mode 100644
index 0000000000000000000000000000000000000000..02593556d88a90232bbe55a062875f4af4520621
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/retinaface/retinaface.py
@@ -0,0 +1,370 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
+
+from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
+from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
+from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
+ py_cpu_nms)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def generate_config(network_name):
+
+ cfg_mnet = {
+ 'name': 'mobilenet0.25',
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 32,
+ 'ngpu': 1,
+ 'epoch': 250,
+ 'decay1': 190,
+ 'decay2': 220,
+ 'image_size': 640,
+ 'return_layers': {
+ 'stage1': 1,
+ 'stage2': 2,
+ 'stage3': 3
+ },
+ 'in_channel': 32,
+ 'out_channel': 64
+ }
+
+ cfg_re50 = {
+ 'name': 'Resnet50',
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 24,
+ 'ngpu': 4,
+ 'epoch': 100,
+ 'decay1': 70,
+ 'decay2': 90,
+ 'image_size': 840,
+ 'return_layers': {
+ 'layer2': 1,
+ 'layer3': 2,
+ 'layer4': 3
+ },
+ 'in_channel': 256,
+ 'out_channel': 256
+ }
+
+ if network_name == 'mobile0.25':
+ return cfg_mnet
+ elif network_name == 'resnet50':
+ return cfg_re50
+ else:
+ raise NotImplementedError(f'network_name={network_name}')
+
+
+class RetinaFace(nn.Module):
+
+ def __init__(self, network_name='resnet50', half=False, phase='test'):
+ super(RetinaFace, self).__init__()
+ self.half_inference = half
+ cfg = generate_config(network_name)
+ self.backbone = cfg['name']
+
+ self.model_name = f'retinaface_{network_name}'
+ self.cfg = cfg
+ self.phase = phase
+ self.target_size, self.max_size = 1600, 2150
+ self.resize, self.scale, self.scale1 = 1., None, None
+ self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device)
+ self.reference = get_reference_facial_points(default_square=True)
+ # Build network.
+ backbone = None
+ if cfg['name'] == 'mobilenet0.25':
+ backbone = MobileNetV1()
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+ elif cfg['name'] == 'Resnet50':
+ import torchvision.models as models
+ backbone = models.resnet50(pretrained=False)
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+
+ in_channels_stage2 = cfg['in_channel']
+ in_channels_list = [
+ in_channels_stage2 * 2,
+ in_channels_stage2 * 4,
+ in_channels_stage2 * 8,
+ ]
+
+ out_channels = cfg['out_channel']
+ self.fpn = FPN(in_channels_list, out_channels)
+ self.ssh1 = SSH(out_channels, out_channels)
+ self.ssh2 = SSH(out_channels, out_channels)
+ self.ssh3 = SSH(out_channels, out_channels)
+
+ self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
+ self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
+ self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
+
+ self.to(device)
+ self.eval()
+ if self.half_inference:
+ self.half()
+
+ def forward(self, inputs):
+ out = self.body(inputs)
+
+ if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
+ out = list(out.values())
+ # FPN
+ fpn = self.fpn(out)
+
+ # SSH
+ feature1 = self.ssh1(fpn[0])
+ feature2 = self.ssh2(fpn[1])
+ feature3 = self.ssh3(fpn[2])
+ features = [feature1, feature2, feature3]
+
+ bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
+ ldm_regressions = (torch.cat(tmp, dim=1))
+
+ if self.phase == 'train':
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+ return output
+
+ def __detect_faces(self, inputs):
+ # get scale
+ height, width = inputs.shape[2:]
+ self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
+ tmp = [width, height, width, height, width, height, width, height, width, height]
+ self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
+
+ # forward

+ inputs = inputs.to(device)
+ if self.half_inference:
+ inputs = inputs.half()
+ loc, conf, landmarks = self(inputs)
+
+ # get priorbox
+ priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
+ priors = priorbox.forward().to(device)
+
+ return loc, conf, landmarks, priors
+
+ # single image detection
+ def transform(self, image, use_origin_size):
+ # convert to opencv format
+ if isinstance(image, Image.Image):
+ image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+ image = image.astype(np.float32)
+
+ # testing scale
+ im_size_min = np.min(image.shape[0:2])
+ im_size_max = np.max(image.shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+
+ # convert to torch.tensor format
+ # image -= (104, 117, 123)
+ image = image.transpose(2, 0, 1)
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ return image, resize
+
+ def detect_faces(
+ self,
+ image,
+ conf_threshold=0.8,
+ nms_threshold=0.4,
+ use_origin_size=True,
+ ):
+ """
+ Params:
+ image: BGR image (numpy array) or PIL.Image.
+ Returns:
+ np.array of shape [n_faces, 15]: bounding box (4), score (1), landmarks (10).
+ """
+ image, self.resize = self.transform(image, use_origin_size)
+ image = image.to(device)
+ if self.half_inference:
+ image = image.half()
+ image = image - self.mean_tensor
+
+ loc, conf, landmarks, priors = self.__detect_faces(image)
+
+ boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
+ boxes = boxes * self.scale / self.resize
+ boxes = boxes.cpu().numpy()
+
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+
+ landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
+ landmarks = landmarks * self.scale1 / self.resize
+ landmarks = landmarks.cpu().numpy()
+
+ # ignore low scores
+ inds = np.where(scores > conf_threshold)[0]
+ boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
+
+ # sort
+ order = scores.argsort()[::-1]
+ boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
+
+ # do NMS
+ bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
+ # self.t['forward_pass'].toc()
+ # print(self.t['forward_pass'].average_time)
+ # import sys
+ # sys.stdout.flush()
+ return np.concatenate((bounding_boxes, landmarks), axis=1)
+
+ def __align_multi(self, image, boxes, landmarks, limit=None):
+
+ if len(boxes) < 1:
+ return [], []
+
+ if limit:
+ boxes = boxes[:limit]
+ landmarks = landmarks[:limit]
+
+ faces = []
+ for landmark in landmarks:
+ facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
+
+ warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
+ faces.append(warped_face)
+
+ return np.concatenate((boxes, landmarks), axis=1), faces
+
+ def align_multi(self, img, conf_threshold=0.8, limit=None):
+
+ rlt = self.detect_faces(img, conf_threshold=conf_threshold)
+ boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
+
+ return self.__align_multi(img, boxes, landmarks, limit)
+
+ # batched detection
+ def batched_transform(self, frames, use_origin_size):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or a torch.Tensor of shape [n, h, w, c]
+ (float32, BGR format).
+ use_origin_size: whether to use origin size.
+ """
+ from_PIL = isinstance(frames[0], Image.Image)
+
+ # convert to opencv format
+ if from_PIL:
+ frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
+ frames = np.asarray(frames, dtype=np.float32)
+
+ # testing scale
+ im_size_min = np.min(frames[0].shape[0:2])
+ im_size_max = np.max(frames[0].shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ if not from_PIL:
+ frames = F.interpolate(frames, scale_factor=resize)
+ else:
+ frames = [
+ cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+ for frame in frames
+ ]
+
+ # convert to torch.tensor format
+ if not from_PIL:
+ frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
+ else:
+ frames = frames.transpose((0, 3, 1, 2))
+ frames = torch.from_numpy(frames)
+
+ return frames, resize
+
+ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
+ type=np.uint8, BGR format).
+ conf_threshold: confidence threshold.
+ nms_threshold: nms threshold.
+ use_origin_size: whether to use origin size.
+ Returns:
+ final_bounding_boxes: list of np.array ([n_boxes, 5],
+ type=np.float32).
+ final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
+ """
+ # self.t['forward_pass'].tic()
+ frames, self.resize = self.batched_transform(frames, use_origin_size)
+ frames = frames.to(device)
+ frames = frames - self.mean_tensor
+
+ b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
+
+ final_bounding_boxes, final_landmarks = [], []
+
+ # decode
+ priors = priors.unsqueeze(0)
+ b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
+ b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
+ b_conf = b_conf[:, :, 1]
+
+ # index for selection
+ b_indice = b_conf > conf_threshold
+
+ # concat
+ b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
+
+ for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
+
+ # ignore low scores
+ pred, landm = pred[inds, :], landm[inds, :]
+ if pred.shape[0] == 0:
+ final_bounding_boxes.append(np.array([], dtype=np.float32))
+ final_landmarks.append(np.array([], dtype=np.float32))
+ continue
+
+ # sort
+ # order = score.argsort(descending=True)
+ # box, landm, score = box[order], landm[order], score[order]
+
+ # to CPU
+ bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
+
+ # NMS
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
+
+ # append
+ final_bounding_boxes.append(bounding_boxes)
+ final_landmarks.append(landmarks)
+ # self.t['forward_pass'].toc(average=True)
+ # self.batch_time += self.t['forward_pass'].diff
+ # self.total_frame += len(frames)
+ # print(self.batch_time / self.total_frame)
+
+ return final_bounding_boxes, final_landmarks
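For orientation, a minimal usage sketch of the detector above (not part of the upstream file). It assumes trained RetinaFace weights are loaded separately; the checkpoint and image file names are placeholders.

import cv2
import torch
from facelib.detection.retinaface.retinaface import RetinaFace

detector = RetinaFace(network_name='resnet50', half=False)
# hypothetical weights path; load whatever checkpoint matches this backbone
detector.load_state_dict(torch.load('detection_Resnet50_Final.pth', map_location='cpu'), strict=False)

img = cv2.imread('photo.jpg')  # BGR, as detect_faces expects
dets = detector.detect_faces(img, conf_threshold=0.8, nms_threshold=0.4)
for det in dets:
    x1, y1, x2, y2, score = det[:5]    # bounding box plus confidence
    landmarks = det[5:].reshape(5, 2)  # five facial keypoints as (x, y)

align_multi(img) follows the same path but additionally returns 112x112 face crops warped from these landmarks.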
diff --git a/repositories/CodeFormer/facelib/detection/retinaface/retinaface_net.py b/repositories/CodeFormer/facelib/detection/retinaface/retinaface_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6aa82d3e9055a838f1f9076b12f05fdfc154d0
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/retinaface/retinaface_net.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_bn(inp, oup, stride=1, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_bn_no_relu(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+
+def conv_bn1X1(inp, oup, stride, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_dw(inp, oup, stride, leaky=0.1):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ )
+
+
+class SSH(nn.Module):
+
+ def __init__(self, in_channel, out_channel):
+ super(SSH, self).__init__()
+ assert out_channel % 4 == 0
+ leaky = 0
+ if (out_channel <= 64):
+ leaky = 0.1
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
+
+ self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
+ self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
+ self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ def forward(self, input):
+ conv3X3 = self.conv3X3(input)
+
+ conv5X5_1 = self.conv5X5_1(input)
+ conv5X5 = self.conv5X5_2(conv5X5_1)
+
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
+ conv7X7 = self.conv7x7_3(conv7X7_2)
+
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+ out = F.relu(out)
+ return out
+
+
+class FPN(nn.Module):
+
+ def __init__(self, in_channels_list, out_channels):
+ super(FPN, self).__init__()
+ leaky = 0
+ if (out_channels <= 64):
+ leaky = 0.1
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
+
+ self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
+ self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
+
+ def forward(self, input):
+ # names = list(input.keys())
+ # input = list(input.values())
+
+ output1 = self.output1(input[0])
+ output2 = self.output2(input[1])
+ output3 = self.output3(input[2])
+
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
+ output2 = output2 + up3
+ output2 = self.merge2(output2)
+
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
+ output1 = output1 + up2
+ output1 = self.merge1(output1)
+
+ out = [output1, output2, output3]
+ return out
+
+
+class MobileNetV1(nn.Module):
+
+ def __init__(self):
+ super(MobileNetV1, self).__init__()
+ self.stage1 = nn.Sequential(
+ conv_bn(3, 8, 2, leaky=0.1), # 3
+ conv_dw(8, 16, 1), # 7
+ conv_dw(16, 32, 2), # 11
+ conv_dw(32, 32, 1), # 19
+ conv_dw(32, 64, 2), # 27
+ conv_dw(64, 64, 1), # 43
+ )
+ self.stage2 = nn.Sequential(
+ conv_dw(64, 128, 2), # 43 + 16 = 59
+ conv_dw(128, 128, 1), # 59 + 32 = 91
+ conv_dw(128, 128, 1), # 91 + 32 = 123
+ conv_dw(128, 128, 1), # 123 + 32 = 155
+ conv_dw(128, 128, 1), # 155 + 32 = 187
+ conv_dw(128, 128, 1), # 187 + 32 = 219
+ )
+ self.stage3 = nn.Sequential(
+ conv_dw(128, 256, 2), # 219 + 32 = 251
+ conv_dw(256, 256, 1), # 251 + 64 = 315
+ )
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(256, 1000)
+
+ def forward(self, x):
+ x = self.stage1(x)
+ x = self.stage2(x)
+ x = self.stage3(x)
+ x = self.avg(x)
+ # x = self.model(x)
+ x = x.view(-1, 256)
+ x = self.fc(x)
+ return x
+
+
+class ClassHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(ClassHead, self).__init__()
+ self.num_anchors = num_anchors
+ self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 2)
+
+
+class BboxHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(BboxHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 4)
+
+
+class LandmarkHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(LandmarkHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 10)
+
+
+def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
+ classhead = nn.ModuleList()
+ for i in range(fpn_num):
+ classhead.append(ClassHead(inchannels, anchor_num))
+ return classhead
+
+
+def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
+ bboxhead = nn.ModuleList()
+ for i in range(fpn_num):
+ bboxhead.append(BboxHead(inchannels, anchor_num))
+ return bboxhead
+
+
+def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
+ landmarkhead = nn.ModuleList()
+ for i in range(fpn_num):
+ landmarkhead.append(LandmarkHead(inchannels, anchor_num))
+ return landmarkhead
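A quick shape check for the FPN and SSH blocks defined above, using the mobilenet0.25 channel configuration (in_channel=32, out_channel=64) as an illustrative assumption:

import torch

from facelib.detection.retinaface.retinaface_net import FPN, SSH

feats = [torch.randn(1, 64, 80, 80),   # stage1 output: in_channel * 2 channels
         torch.randn(1, 128, 40, 40),  # stage2 output: in_channel * 4 channels
         torch.randn(1, 256, 20, 20)]  # stage3 output: in_channel * 8 channels
fpn = FPN([64, 128, 256], 64)
ssh = SSH(64, 64)
outs = [ssh(f) for f in fpn(feats)]
print([tuple(o.shape) for o in outs])  # 64 channels at each of the three pyramid levels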
diff --git a/repositories/CodeFormer/facelib/detection/retinaface/retinaface_utils.py b/repositories/CodeFormer/facelib/detection/retinaface/retinaface_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c357757741c6d9bd7ce4d8ce740fefd51850fbf
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/retinaface/retinaface_utils.py
@@ -0,0 +1,421 @@
+import numpy as np
+import torch
+import torchvision
+from itertools import product as product
+from math import ceil
+
+
+class PriorBox(object):
+
+ def __init__(self, cfg, image_size=None, phase='train'):
+ super(PriorBox, self).__init__()
+ self.min_sizes = cfg['min_sizes']
+ self.steps = cfg['steps']
+ self.clip = cfg['clip']
+ self.image_size = image_size
+ self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
+ self.name = 's'
+
+ def forward(self):
+ anchors = []
+ for k, f in enumerate(self.feature_maps):
+ min_sizes = self.min_sizes[k]
+ for i, j in product(range(f[0]), range(f[1])):
+ for min_size in min_sizes:
+ s_kx = min_size / self.image_size[1]
+ s_ky = min_size / self.image_size[0]
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+
+ # back to torch land
+ output = torch.Tensor(anchors).view(-1, 4)
+ if self.clip:
+ output.clamp_(max=1, min=0)
+ return output
+
+
+def py_cpu_nms(dets, thresh):
+ """Pure Python NMS baseline."""
+ keep = torchvision.ops.nms(
+ boxes=torch.Tensor(dets[:, :4]),
+ scores=torch.Tensor(dets[:, 4]),
+ iou_threshold=thresh,
+ )
+
+ return list(keep)
+
+
+def point_form(boxes):
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+ representation for comparison to point form ground truth data.
+ Args:
+ boxes: (tensor) center-size default boxes from priorbox layers.
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat(
+ (
+ boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
+ boxes[:, :2] + boxes[:, 2:] / 2),
+ 1) # xmax, ymax
+
+
+def center_size(boxes):
+ """ Convert prior_boxes to (cx, cy, w, h)
+ representation for comparison to center-size form ground truth data.
+ Args:
+ boxes: (tensor) point_form boxes
+ Return:
+ boxes: (tensor) Converted cx, cy, w, h form of boxes.
+ """
+ return torch.cat(
+ ((boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
+ boxes[:, 2:] - boxes[:, :2]), # w, h
+ 1)
+
+
+def intersect(box_a, box_b):
+ """ We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.size(0)
+ B = box_b.size(0)
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+ inter = torch.clamp((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
+ area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
+
+
+def matrix_iou(a, b):
+ """
+ return iou of a and b, numpy version for data augmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+ """
+ return iof of a and b, numpy version for data augmentation
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+ """Match each prior box with the ground truth box of the highest jaccard
+ overlap, encode the bounding boxes, then return the matched indices
+ corresponding to both confidence and location preds.
+ Args:
+ threshold: (float) The overlap threshold used when matching boxes.
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+ variances: (tensor) Variances corresponding to each prior coord,
+ Shape: [num_priors, 4].
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+ loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+ landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+ idx: (int) current batch index
+ Return:
+ The matched indices corresponding to 1)location 2)confidence
+ 3)landm preds.
+ """
+ # jaccard index
+ overlaps = jaccard(truths, point_form(priors))
+ # (Bipartite Matching)
+ # [1,num_objects] best prior for each ground truth
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+ # ignore hard gt
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+ if best_prior_idx_filter.shape[0] <= 0:
+ loc_t[idx] = 0
+ conf_t[idx] = 0
+ return
+
+ # [1,num_priors] best ground truth for each prior
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+ best_truth_idx.squeeze_(0)
+ best_truth_overlap.squeeze_(0)
+ best_prior_idx.squeeze_(1)
+ best_prior_idx_filter.squeeze_(1)
+ best_prior_overlap.squeeze_(1)
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
+ # TODO refactor: index best_prior_idx with long tensor
+ # ensure every gt matches with its prior of max overlap
+ for j in range(best_prior_idx.size(0)): # ensure each ground truth keeps its best-matching anchor
+ best_truth_idx[best_prior_idx[j]] = j
+ matches = truths[best_truth_idx] # Shape: [num_priors,4] gather the matched gt box for every anchor
+ conf = labels[best_truth_idx] # Shape: [num_priors] gather the matched gt label for every anchor
+ conf[best_truth_overlap < threshold] = 0 # label as background: anchors below the overlap threshold become negatives
+ loc = encode(matches, priors, variances)
+
+ matches_landm = landms[best_truth_idx]
+ landm = encode_landm(matches_landm, priors, variances)
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
+ conf_t[idx] = conf # [num_priors] top class label for each prior
+ landm_t[idx] = landm
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+def encode_landm(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 10].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded landm (tensor), Shape: [num_priors, 10]
+ """
+
+ # dist b/t match center and prior's center
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
+ priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, :, 2:])
+ # g_cxcy /= priors[:, :, 2:]
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+ # return target for smooth_l1_loss
+ return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ tmp = (
+ priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+ )
+ landms = torch.cat(tmp, dim=1)
+ return landms
+
+
+def batched_decode(b_loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ b_loc (tensor): location predictions for loc layers,
+ Shape: [num_batches,num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+ boxes = (
+ priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
+ )
+ boxes = torch.cat(boxes, dim=2)
+
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+ boxes[:, :, 2:] += boxes[:, :, :2]
+ return boxes
+
+
+def batched_decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_batches,num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ landms = (
+ priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
+ )
+ landms = torch.cat(landms, dim=2)
+ return landms
+
+
+def log_sum_exp(x):
+ """Utility function for computing log_sum_exp while determining
+ This will be used to determine unaveraged confidence loss across
+ all examples in a batch.
+ Args:
+ x (Variable(tensor)): conf_preds from conf layers
+ """
+ x_max = x.data.max()
+ return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+ """Apply non-maximum suppression at test time to avoid detecting too many
+ overlapping bounding boxes for a given object.
+ Args:
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+ scores: (tensor) The class prediction scores for the img, Shape:[num_priors].
+ overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+ top_k: (int) The Maximum number of box preds to consider.
+ Return:
+ The indices of the kept boxes with respect to num_priors.
+ """
+
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
+ if boxes.numel() == 0:
+ return keep
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ area = torch.mul(x2 - x1, y2 - y1)
+ v, idx = scores.sort(0) # sort in ascending order
+ # I = I[v >= 0.01]
+ idx = idx[-top_k:] # indices of the top-k largest vals
+ xx1 = boxes.new()
+ yy1 = boxes.new()
+ xx2 = boxes.new()
+ yy2 = boxes.new()
+ w = boxes.new()
+ h = boxes.new()
+
+ # keep = torch.Tensor()
+ count = 0
+ while idx.numel() > 0:
+ i = idx[-1] # index of current largest val
+ # keep.append(i)
+ keep[count] = i
+ count += 1
+ if idx.size(0) == 1:
+ break
+ idx = idx[:-1] # remove kept element from view
+ # load bboxes of next highest vals
+ torch.index_select(x1, 0, idx, out=xx1)
+ torch.index_select(y1, 0, idx, out=yy1)
+ torch.index_select(x2, 0, idx, out=xx2)
+ torch.index_select(y2, 0, idx, out=yy2)
+ # store element-wise max with next highest score
+ xx1 = torch.clamp(xx1, min=x1[i])
+ yy1 = torch.clamp(yy1, min=y1[i])
+ xx2 = torch.clamp(xx2, max=x2[i])
+ yy2 = torch.clamp(yy2, max=y2[i])
+ w.resize_as_(xx2)
+ h.resize_as_(yy2)
+ w = xx2 - xx1
+ h = yy2 - yy1
+ # check sizes of xx1 and xx2.. after each iteration
+ w = torch.clamp(w, min=0.0)
+ h = torch.clamp(h, min=0.0)
+ inter = w * h
+ # IoU = i / (area(a) + area(b) - i)
+ rem_areas = torch.index_select(area, 0, idx) # load remaining areas
+ union = (rem_areas - inter) + area[i]
+ IoU = inter / union # store result in iou
+ # keep only elements with an IoU <= overlap
+ idx = idx[IoU.le(overlap)]
+ return keep, count
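To make the prior-box encoding concrete, here is an illustrative round trip through encode() and decode() with mobilenet-style settings (the config values and the ground-truth box are made up for the example):

import torch

from facelib.detection.retinaface.retinaface_utils import PriorBox, decode, encode

cfg = {'min_sizes': [[16, 32], [64, 128], [256, 512]], 'steps': [8, 16, 32], 'clip': False}
priors = PriorBox(cfg, image_size=(64, 64)).forward()  # [num_priors, 4], center-size form
variances = [0.1, 0.2]

gt = torch.tensor([[0.2, 0.2, 0.6, 0.7]]).expand(priors.size(0), 4)  # point-form box per prior
loc = encode(gt, priors, variances)          # offsets the network would regress
recovered = decode(loc, priors, variances)   # back to point form
print(torch.allclose(recovered, gt, atol=1e-5))  # True up to floating-point error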
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/__init__.py b/repositories/CodeFormer/facelib/detection/yolov5face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/face_detector.py b/repositories/CodeFormer/facelib/detection/yolov5face/face_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..2282b283e4446915731180e1d2dff748e8e46ec2
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/face_detector.py
@@ -0,0 +1,142 @@
+import copy
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+from facelib.detection.yolov5face.models.yolo import Model
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+ check_img_size,
+ non_max_suppression_face,
+ scale_coords,
+ scale_coords_landmarks,
+)
+
+IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.'))) >= (1, 9, 0)
+
+
+def isListempty(inList):
+ if isinstance(inList, list): # Is a list
+ return all(map(isListempty, inList))
+ return False # Not a list
+
+class YoloDetector:
+ def __init__(
+ self,
+ config_name,
+ min_face=10,
+ target_size=None,
+ device='cuda',
+ ):
+ """
+ config_name: name of .yaml config with network configuration from models/ folder.
+ min_face : minimal face size in pixels.
+ target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
+ None for original resolution.
+ """
+ self._class_path = Path(__file__).parent.absolute()
+ self.target_size = target_size
+ self.min_face = min_face
+ self.detector = Model(cfg=config_name)
+ self.device = device
+
+
+ def _preprocess(self, imgs):
+ """
+ Preprocessing image before passing through the network. Resize and conversion to torch tensor.
+ """
+ pp_imgs = []
+ for img in imgs:
+ h0, w0 = img.shape[:2] # orig hw
+ if self.target_size:
+ r = self.target_size / min(h0, w0) # resize image to img_size
+ if r < 1:
+ img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+
+ imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size
+ img = letterbox(img, new_shape=imgsz)[0]
+ pp_imgs.append(img)
+ pp_imgs = np.array(pp_imgs)
+ pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
+ pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
+ pp_imgs = pp_imgs.float() # uint8 to fp16/32
+ return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
+
+ def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
+ """
+ Postprocessing of raw pytorch model output.
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ bboxes = [[] for _ in range(len(origimgs))]
+ landmarks = [[] for _ in range(len(origimgs))]
+
+ pred = non_max_suppression_face(pred, conf_thres, iou_thres)
+
+ for image_id, origimg in enumerate(origimgs):
+ img_shape = origimg.shape
+ image_height, image_width = img_shape[:2]
+ gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
+ gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks
+ det = pred[image_id].cpu()
+ scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
+ scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
+
+ for j in range(det.size()[0]):
+ box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
+ box = list(
+ map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
+ )
+ if box[3] - box[1] < self.min_face:
+ continue
+ lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
+ lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
+ lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
+ bboxes[image_id].append(box)
+ landmarks[image_id].append(lm)
+ return bboxes, landmarks
+
+ def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
+ """
+ Get bbox coordinates and keypoints of faces on original image.
+ Params:
+ imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
+ conf_thres: confidence threshold for each prediction
+ iou_thres: threshold for NMS (filter of intersecting bboxes)
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ # Pass input images through face detector
+ images = imgs if isinstance(imgs, list) else [imgs]
+ images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+ origimgs = copy.deepcopy(images)
+
+ images = self._preprocess(images)
+
+ if IS_HIGH_VERSION:
+ with torch.inference_mode(): # for pytorch>=1.9
+ pred = self.detector(images)[0]
+ else:
+ with torch.no_grad(): # for pytorch<1.9
+ pred = self.detector(images)[0]
+
+ bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+ # return bboxes, points
+ if not isListempty(points):
+ bboxes = np.array(bboxes).reshape(-1,4)
+ points = np.array(points).reshape(-1,10)
+ padding = bboxes[:,0].reshape(-1,1)
+ return np.concatenate((bboxes, padding, points), axis=1)
+ else:
+ return None
+
+ def __call__(self, *args):
+ return self.detect_faces(*args) # this class defines detect_faces; no predict() method exists here
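A minimal usage sketch for YoloDetector (not part of the upstream file). The yaml path points at the bundled yolov5n config added below; the checkpoint name and the manual load/move-to-device steps are assumptions, since __init__ only builds the untrained architecture.

import cv2
import torch

from facelib.detection.yolov5face.face_detector import YoloDetector

detector = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml',
                        min_face=10, target_size=640, device='cpu')
detector.detector.load_state_dict(torch.load('yolov5n-face.pth', map_location='cpu'))  # hypothetical checkpoint
detector.detector.to('cpu').eval()

img = cv2.imread('photo.jpg')                    # BGR; detect_faces converts to RGB internally
result = detector.detect_faces(img, conf_thres=0.7, iou_thres=0.5)
if result is not None:
    boxes = result[:, :4]      # x1, y1, x2, y2 per detected face
    landmarks = result[:, 5:]  # ten values: five (x, y) keypoints (column 4 is padding)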
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/models/__init__.py b/repositories/CodeFormer/facelib/detection/yolov5face/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/models/common.py b/repositories/CodeFormer/facelib/detection/yolov5face/models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..497a00444c4c59725001993a63fe4617e9d323c8
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/models/common.py
@@ -0,0 +1,299 @@
+# This file contains modules common to various models
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+ make_divisible,
+ non_max_suppression,
+ scale_coords,
+ xyxy2xywh,
+)
+
+
+def autopad(k, p=None): # kernel, padding
+ # Pad to 'same'
+ if p is None:
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
+ return p
+
+
+def channel_shuffle(x, groups):
+ batchsize, num_channels, height, width = x.data.size()
+ channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc")
+
+ # reshape
+ x = x.view(batchsize, groups, channels_per_group, height, width)
+ x = torch.transpose(x, 1, 2).contiguous()
+
+ # flatten
+ return x.view(batchsize, -1, height, width)
+
+
+def DWConv(c1, c2, k=1, s=1, act=True):
+ # Depthwise convolution
+ return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
+
+
+class Conv(nn.Module):
+ # Standard convolution
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+
+ def forward(self, x):
+ return self.act(self.bn(self.conv(x)))
+
+ def fuseforward(self, x):
+ return self.act(self.conv(x))
+
+
+class StemBlock(nn.Module):
+ def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
+ super().__init__()
+ self.stem_1 = Conv(c1, c2, k, s, p, g, act)
+ self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0)
+ self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1)
+ self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
+ self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0)
+
+ def forward(self, x):
+ stem_1_out = self.stem_1(x)
+ stem_2a_out = self.stem_2a(stem_1_out)
+ stem_2b_out = self.stem_2b(stem_2a_out)
+ stem_2p_out = self.stem_2p(stem_1_out)
+ return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1))
+
+
+class Bottleneck(nn.Module):
+ # Standard bottleneck
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
+ self.act = nn.LeakyReLU(0.1, inplace=True)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ y1 = self.cv3(self.m(self.cv1(x)))
+ y2 = self.cv2(x)
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
+
+
+class C3(nn.Module):
+ # CSP Bottleneck with 3 convolutions
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+class ShuffleV2Block(nn.Module):
+ def __init__(self, inp, oup, stride):
+ super().__init__()
+
+ if not 1 <= stride <= 3:
+ raise ValueError("illegal stride value")
+ self.stride = stride
+
+ branch_features = oup // 2
+
+ if self.stride > 1:
+ self.branch1 = nn.Sequential(
+ self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(inp),
+ nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ )
+ else:
+ self.branch1 = nn.Sequential()
+
+ self.branch2 = nn.Sequential(
+ nn.Conv2d(
+ inp if (self.stride > 1) else branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ ),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(branch_features),
+ nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ )
+
+ @staticmethod
+ def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+ return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+ def forward(self, x):
+ if self.stride == 1:
+ x1, x2 = x.chunk(2, dim=1)
+ out = torch.cat((x1, self.branch2(x2)), dim=1)
+ else:
+ out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+ out = channel_shuffle(out, 2)
+ return out
+
+
+class SPP(nn.Module):
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
+ def __init__(self, c1, c2, k=(5, 9, 13)):
+ super().__init__()
+ c_ = c1 // 2 # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
+
+ def forward(self, x):
+ x = self.cv1(x)
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class Focus(nn.Module):
+ # Focus wh information into c-space
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
+
+
+class Concat(nn.Module):
+ # Concatenate a list of tensors along dimension
+ def __init__(self, dimension=1):
+ super().__init__()
+ self.d = dimension
+
+ def forward(self, x):
+ return torch.cat(x, self.d)
+
+
+class NMS(nn.Module):
+ # Non-Maximum Suppression (NMS) module
+ conf = 0.25 # confidence threshold
+ iou = 0.45 # IoU threshold
+ classes = None # (optional list) filter by class
+
+ def forward(self, x):
+ return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
+
+
+class AutoShape(nn.Module):
+ # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+ img_size = 640 # inference size (pixels)
+ conf = 0.25 # NMS confidence threshold
+ iou = 0.45 # NMS IoU threshold
+ classes = None # (optional list) filter by class
+
+ def __init__(self, model):
+ super().__init__()
+ self.model = model.eval()
+
+ def autoshape(self):
+ print("autoShape already enabled, skipping... ") # model already converted to model.autoshape()
+ return self
+
+ def forward(self, imgs, size=640, augment=False, profile=False):
+ # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
+ # PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
+ # numpy: = np.zeros((720,1280,3)) # HWC
+ # torch: = torch.zeros(16,3,720,1280) # BCHW
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
+
+ p = next(self.model.parameters()) # for device and type
+ if isinstance(imgs, torch.Tensor): # torch
+ return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
+
+ # Pre-process
+ n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
+ shape0, shape1 = [], [] # image and inference shapes
+ for i, im in enumerate(imgs):
+ im = np.array(im) # to numpy
+ if im.shape[0] < 5: # image in CHW
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
+ im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
+ s = im.shape[:2] # HWC
+ shape0.append(s) # image shape
+ g = size / max(s) # gain
+ shape1.append([y * g for y in s])
+ imgs[i] = im # update
+ shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
+ x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
+ x = np.stack(x, 0) if n > 1 else x[0][None] # stack
+ x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32
+
+ # Inference
+ with torch.no_grad():
+ y = self.model(x, augment, profile)[0] # forward
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
+
+ # Post-process
+ for i in range(n):
+ scale_coords(shape1, y[i][:, :4], shape0[i])
+
+ return Detections(imgs, y, self.names)
+
+
+class Detections:
+ # detections class for YOLOv5 inference results
+ def __init__(self, imgs, pred, names=None):
+ super().__init__()
+ d = pred[0].device # device
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations
+ self.imgs = imgs # list of images as numpy arrays
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
+ self.names = names # class names
+ self.xyxy = pred # xyxy pixels
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
+ self.n = len(self.pred)
+
+ def __len__(self):
+ return self.n
+
+ def tolist(self):
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
+ x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+ for d in x:
+ for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]:
+ setattr(d, k, getattr(d, k)[0]) # pop out of list
+ return x
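A small shape sanity check for the building blocks above (purely illustrative; the channel counts are arbitrary):

import torch

from facelib.detection.yolov5face.models.common import C3, ShuffleV2Block, StemBlock

x = torch.randn(1, 3, 64, 64)
x = StemBlock(3, 16)(x)                  # stride-4 stem: [1, 16, 16, 16]
x = ShuffleV2Block(16, 32, stride=2)(x)  # downsample + channel shuffle: [1, 32, 8, 8]
x = C3(32, 64)(x)                        # CSP bottleneck with 3 convolutions: [1, 64, 8, 8]
print(tuple(x.shape))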
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/models/experimental.py b/repositories/CodeFormer/facelib/detection/yolov5face/models/experimental.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ba4c4420789c92dc0e2aaeb3d5b64859ec728c
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/models/experimental.py
@@ -0,0 +1,45 @@
+# # This file contains experimental modules
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+
+
+class CrossConv(nn.Module):
+ # Cross Convolution Downsample
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
+ # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, (1, k), (1, s))
+ self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class MixConv2d(nn.Module):
+ # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
+ super().__init__()
+ groups = len(k)
+ if equal_ch: # equal c_ per group
+ i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices
+ c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
+ else: # equal weight.numel() per group
+ b = [c2] + [0] * groups
+ a = np.eye(groups + 1, groups, k=-1)
+ a -= np.roll(a, 1, axis=1)
+ a *= np.array(k) ** 2
+ a[0] = 1
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
+
+ self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = nn.LeakyReLU(0.1, inplace=True)
+
+ def forward(self, x):
+ return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
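MixConv2d splits its output channels across several kernel sizes and adds a residual, so the forward pass above only works when c1 == c2 and the stride is 1; a quick smoke test under that assumption:

import torch

from facelib.detection.yolov5face.models.experimental import MixConv2d

m = MixConv2d(32, 32, k=(1, 3), s=1)  # half the channels use 1x1 kernels, half 3x3
print(tuple(m(torch.randn(1, 32, 16, 16)).shape))  # (1, 32, 16, 16)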
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/models/yolo.py b/repositories/CodeFormer/facelib/detection/yolov5face/models/yolo.py
new file mode 100644
index 0000000000000000000000000000000000000000..70845d972f0bcfd3632fcbac096b23e1b4d4d779
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/models/yolo.py
@@ -0,0 +1,235 @@
+import math
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+import yaml # for torch hub
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import (
+ C3,
+ NMS,
+ SPP,
+ AutoShape,
+ Bottleneck,
+ BottleneckCSP,
+ Concat,
+ Conv,
+ DWConv,
+ Focus,
+ ShuffleV2Block,
+ StemBlock,
+)
+from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d
+from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
+from facelib.detection.yolov5face.utils.general import make_divisible
+from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn
+
+
+class Detect(nn.Module):
+ stride = None # strides computed during build
+ export = False # onnx export
+
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
+ super().__init__()
+ self.nc = nc # number of classes
+ self.no = nc + 5 + 10 # number of outputs per anchor
+
+ self.nl = len(anchors) # number of detection layers
+ self.na = len(anchors[0]) // 2 # number of anchors
+ self.grid = [torch.zeros(1)] * self.nl # init grid
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+ self.register_buffer("anchors", a) # shape(nl,na,2)
+ self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
+
+ def forward(self, x):
+ z = [] # inference output
+ if self.export:
+ for i in range(self.nl):
+ x[i] = self.m[i](x[i])
+ return x
+ for i in range(self.nl):
+ x[i] = self.m[i](x[i]) # conv
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+ if not self.training: # inference
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+ y = torch.full_like(x[i], 0)
+ y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
+ y[..., 5:15] = x[i][..., 5:15]
+
+ y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
+
+ y[..., 5:7] = (
+ y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x1 y1
+ y[..., 7:9] = (
+ y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x2 y2
+ y[..., 9:11] = (
+ y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x3 y3
+ y[..., 11:13] = (
+ y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x4 y4
+ y[..., 13:15] = (
+ y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x5 y5
+
+ z.append(y.view(bs, -1, self.no))
+
+ return x if self.training else (torch.cat(z, 1), x)
+
+ @staticmethod
+ def _make_grid(nx=20, ny=20):
+ # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+
+
+class Model(nn.Module):
+ def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes
+ super().__init__()
+ self.yaml_file = Path(cfg).name
+ with Path(cfg).open(encoding="utf8") as f:
+ self.yaml = yaml.safe_load(f) # model dict
+
+ # Define model
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
+ if nc and nc != self.yaml["nc"]:
+ self.yaml["nc"] = nc # override yaml value
+
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
+ self.names = [str(i) for i in range(self.yaml["nc"])] # default names
+
+ # Build strides, anchors
+ m = self.model[-1] # Detect()
+ if isinstance(m, Detect):
+ s = 128 # 2x min stride
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
+ m.anchors /= m.stride.view(-1, 1, 1)
+ check_anchor_order(m)
+ self.stride = m.stride
+ self._initialize_biases() # only run once
+
+ def forward(self, x):
+ return self.forward_once(x) # single-scale inference, train
+
+ def forward_once(self, x):
+ y = [] # outputs
+ for m in self.model:
+ if m.f != -1: # if not from previous layer
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
+
+ x = m(x) # run
+ y.append(x if m.i in self.save else None) # save output
+
+ return x
+
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
+ # https://arxiv.org/abs/1708.02002 section 3.3
+ m = self.model[-1] # Detect() module
+ for mi, s in zip(m.m, m.stride): # from
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+ def _print_biases(self):
+ m = self.model[-1] # Detect() module
+ for mi in m.m: # from
+ b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
+ print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
+
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
+ print("Fusing layers... ")
+ for m in self.model.modules():
+ if isinstance(m, Conv) and hasattr(m, "bn"):
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
+ delattr(m, "bn") # remove batchnorm
+ m.forward = m.fuseforward # update forward
+ elif type(m) is nn.Upsample:
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
+ return self
+
+ def nms(self, mode=True): # add or remove NMS module
+ present = isinstance(self.model[-1], NMS) # last layer is NMS
+ if mode and not present:
+ print("Adding NMS... ")
+ m = NMS() # module
+ m.f = -1 # from
+ m.i = self.model[-1].i + 1 # index
+ self.model.add_module(name=str(m.i), module=m) # add
+ self.eval()
+ elif not mode and present:
+ print("Removing NMS... ")
+ self.model = self.model[:-1] # remove
+ return self
+
+ def autoshape(self): # add autoShape module
+ print("Adding autoShape... ")
+ m = AutoShape(self) # wrap model
+ copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes
+ return m
+
+
+def parse_model(d, ch): # model_dict, input_channels(3)
+ anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
+
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
+ m = eval(m) if isinstance(m, str) else m # eval strings
+ for j, a in enumerate(args):
+ try:
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
+ except:
+ pass
+
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
+ if m in [
+ Conv,
+ Bottleneck,
+ SPP,
+ DWConv,
+ MixConv2d,
+ Focus,
+ CrossConv,
+ BottleneckCSP,
+ C3,
+ ShuffleV2Block,
+ StemBlock,
+ ]:
+ c1, c2 = ch[f], args[0]
+
+ c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
+
+ args = [c1, c2, *args[1:]]
+ if m in [BottleneckCSP, C3]:
+ args.insert(2, n)
+ n = 1
+ elif m is nn.BatchNorm2d:
+ args = [ch[f]]
+ elif m is Concat:
+ c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
+ elif m is Detect:
+ args.append([ch[x + 1] for x in f])
+ if isinstance(args[1], int): # number of anchors
+ args[1] = [list(range(args[1] * 2))] * len(f)
+ else:
+ c2 = ch[f]
+
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
+ t = str(m)[8:-2].replace("__main__.", "") # module type
+ np = sum(x.numel() for x in m_.parameters()) # number params
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
+ layers.append(m_)
+ ch.append(c2)
+ return nn.Sequential(*layers), sorted(save)
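+
+# Illustrative note (added, not part of the upstream file): a config row such as
+# [-1, 1, Conv, [256, 3, 2]] is parsed as f=-1 (input from the previous layer), n=1 (repeats,
+# scaled by depth_multiple), m=Conv, args=[256, 3, 2]; parse_model then rewrites args to
+# [c1, make_divisible(256 * width_multiple, 8), 3, 2] before instantiating the module.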
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/models/yolov5l.yaml b/repositories/CodeFormer/facelib/detection/yolov5face/models/yolov5l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0532b0e22fa7f59349b178146ffddcfdb368aba6
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/models/yolov5l.yaml
@@ -0,0 +1,47 @@
+# parameters
+nc: 1 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# anchors
+anchors:
+ - [4,5, 8,10, 13,16] # P3/8
+ - [23,29, 43,55, 73,105] # P4/16
+ - [146,217, 231,300, 335,433] # P5/32
+
+# YOLOv5 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 2-P3/8
+ [-1, 9, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 4-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32
+ [-1, 1, SPP, [1024, [3,5,7]]],
+ [-1, 3, C3, [1024, False]], # 8
+ ]
+
+# YOLOv5 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 5], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 12
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 3], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 16 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 13], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 19 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 9], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 22 (P5/32-large)
+
+ [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/models/yolov5n.yaml b/repositories/CodeFormer/facelib/detection/yolov5face/models/yolov5n.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..caba6bed674aa2213b110f19e04eb352ffbeaf1e
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/models/yolov5n.yaml
@@ -0,0 +1,45 @@
+# parameters
+nc: 1 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# anchors
+anchors:
+ - [4,5, 8,10, 13,16] # P3/8
+ - [23,29, 43,55, 73,105] # P4/16
+ - [146,217, 231,300, 335,433] # P5/32
+
+# YOLOv5 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4
+ [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8
+ [-1, 3, ShuffleV2Block, [128, 1]], # 2
+ [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16
+ [-1, 7, ShuffleV2Block, [256, 1]], # 4
+ [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32
+ [-1, 3, ShuffleV2Block, [512, 1]], # 6
+ ]
+
+# YOLOv5 head
+head:
+ [[-1, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P4
+ [-1, 1, C3, [128, False]], # 10
+
+ [-1, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 2], 1, Concat, [1]], # cat backbone P3
+ [-1, 1, C3, [128, False]], # 14 (P3/8-small)
+
+ [-1, 1, Conv, [128, 3, 2]],
+ [[-1, 11], 1, Concat, [1]], # cat head P4
+ [-1, 1, C3, [128, False]], # 17 (P4/16-medium)
+
+ [-1, 1, Conv, [128, 3, 2]],
+ [[-1, 7], 1, Concat, [1]], # cat head P5
+ [-1, 1, C3, [128, False]], # 20 (P5/32-large)
+
+ [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
+ ]
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/utils/__init__.py b/repositories/CodeFormer/facelib/detection/yolov5face/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/utils/autoanchor.py b/repositories/CodeFormer/facelib/detection/yolov5face/utils/autoanchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4eba3e94888709be7d2a7c7499fbcc1808b4a88
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/utils/autoanchor.py
@@ -0,0 +1,12 @@
+# Auto-anchor utils
+
+
+def check_anchor_order(m):
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
+ a = m.anchor_grid.prod(-1).view(-1) # anchor area
+ da = a[-1] - a[0] # delta a
+ ds = m.stride[-1] - m.stride[0] # delta s
+ if da.sign() != ds.sign(): # anchor order and stride order disagree
+ print("Reversing anchor order")
+ m.anchors[:] = m.anchors.flip(0)
+ m.anchor_grid[:] = m.anchor_grid.flip(0)
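+
+# Note (added for clarity): anchor areas are expected to grow with stride (small anchors on the
+# high-resolution P3 level, large ones on P5); if the two deltas have opposite signs, the anchor
+# rows are flipped so both sequences increase together.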
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/utils/datasets.py b/repositories/CodeFormer/facelib/detection/yolov5face/utils/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..e672b136f56fd6b05038e24377908361a54fe519
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/utils/datasets.py
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
+ # Resize and pad image to a stride-multiple rectangle (64 px here) https://github.com/ultralytics/yolov3/issues/232
+ shape = img.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better test mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
+ elif scale_fill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return img, ratio, (dw, dh)
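+
+# Illustrative usage (hypothetical file name, added for clarity):
+# img = cv2.imread('face.jpg')
+# img_padded, ratio, (dw, dh) = letterbox(img, new_shape=640) # pads each side to a multiple of 64 when auto=True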
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/utils/extract_ckpt.py b/repositories/CodeFormer/facelib/detection/yolov5face/utils/extract_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8b631348f2d0cdea4e5a3594bb59f3e8f34a0f
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/utils/extract_ckpt.py
@@ -0,0 +1,5 @@
+import torch
+import sys
+sys.path.insert(0,'./facelib/detection/yolov5face')
+model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
+torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth')
\ No newline at end of file
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/utils/general.py b/repositories/CodeFormer/facelib/detection/yolov5face/utils/general.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8e14f56a107ec3a4269c382cfc5168ad780ffc
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/utils/general.py
@@ -0,0 +1,271 @@
+import math
+import time
+
+import numpy as np
+import torch
+import torchvision
+
+
+def check_img_size(img_size, s=32):
+ # Verify img_size is a multiple of stride s
+ new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
+ # if new_size != img_size:
+ # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
+ return new_size
+
+
+def make_divisible(x, divisor):
+ # Returns x evenly divisible by divisor
+ return math.ceil(x / divisor) * divisor
+
+
+def xyxy2xywh(x):
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
+ y[:, 2] = x[:, 2] - x[:, 0] # width
+ y[:, 3] = x[:, 3] - x[:, 1] # height
+ return y
+
+
+def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+ coords[:, :4] /= gain
+ clip_coords(coords, img0_shape)
+ return coords
+
+
+def clip_coords(boxes, img_shape):
+ # Clip xyxy bounding boxes to image shape (height, width)
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
+
+
+def box_iou(box1, box2):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+ return inter / (area1[:, None] + area2 - inter)
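+
+# Shape note (added for clarity): with box1 of shape (N, 4) and box2 of shape (M, 4), the
+# broadcasting via box1[:, None, :] yields an (N, M) matrix of pairwise IoU values.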
+
+
+def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+ Returns:
+ detections with shape: nx16 (x1, y1, x2, y2, conf, 10 landmark coords, cls)
+ """
+
+ nc = prediction.shape[2] - 15 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label = labels[xi]
+ v = torch.zeros((len(label), nc + 15), device=x.device)
+ v[:, :4] = label[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx16 (xyxy, conf, 10 landmark coords, cls)
+ if multi_label:
+ i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 15:].max(1, keepdim=True)
+ x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # If none remain process next image
+ n = x.shape[0] # number of boxes
+ if not n:
+ continue
+
+ # Batched NMS
+ c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ break # time limit exceeded
+
+ return output
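+
+# Layout note (added for clarity): each row of the returned detections has 16 columns:
+# 4 box coords (x1, y1, x2, y2), 1 confidence, 10 landmark coords (5 facial points as x, y
+# pairs) and 1 class index, matching the (0, 16) shape of the empty default tensors.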
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ """
+
+ nc = prediction.shape[2] - 5 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label_id = labels[xi]
+ v = torch.zeros((len(label_id), nc + 5), device=x.device)
+ v[:, :4] = label_id[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 5:].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ print(f"WARNING: NMS time limit {time_limit}s exceeded")
+ break # time limit exceeded
+
+ return output
+
+
+def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
+ coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
+ coords[:, :10] /= gain
+ coords[:, 0].clamp_(0, img0_shape[1]) # x1
+ coords[:, 1].clamp_(0, img0_shape[0]) # y1
+ coords[:, 2].clamp_(0, img0_shape[1]) # x2
+ coords[:, 3].clamp_(0, img0_shape[0]) # y2
+ coords[:, 4].clamp_(0, img0_shape[1]) # x3
+ coords[:, 5].clamp_(0, img0_shape[0]) # y3
+ coords[:, 6].clamp_(0, img0_shape[1]) # x4
+ coords[:, 7].clamp_(0, img0_shape[0]) # y4
+ coords[:, 8].clamp_(0, img0_shape[1]) # x5
+ coords[:, 9].clamp_(0, img0_shape[0]) # y5
+ return coords
diff --git a/repositories/CodeFormer/facelib/detection/yolov5face/utils/torch_utils.py b/repositories/CodeFormer/facelib/detection/yolov5face/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af2d06587b2d07b2eab199a8484380fde1de5c3c
--- /dev/null
+++ b/repositories/CodeFormer/facelib/detection/yolov5face/utils/torch_utils.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+
+def fuse_conv_and_bn(conv, bn):
+ # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+ fusedconv = (
+ nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ kernel_size=conv.kernel_size,
+ stride=conv.stride,
+ padding=conv.padding,
+ groups=conv.groups,
+ bias=True,
+ )
+ .requires_grad_(False)
+ .to(conv.weight.device)
+ )
+
+ # prepare filters
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+
+ # prepare spatial bias
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
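+ # In effect (added note): W_fused = diag(gamma / sqrt(running_var + eps)) @ W_conv and
+ # b_fused = gamma * (b_conv - running_mean) / sqrt(running_var + eps) + beta, so the fused
+ # convolution reproduces conv followed by batchnorm in eval mode.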
+ return fusedconv
+
+
+def copy_attr(a, b, include=(), exclude=()):
+ # Copy attributes from b to a, options to only include [...] and to exclude [...]
+ for k, v in b.__dict__.items():
+ if (include and k not in include) or k.startswith("_") or k in exclude:
+ continue
+
+ setattr(a, k, v)
diff --git a/repositories/CodeFormer/facelib/parsing/__init__.py b/repositories/CodeFormer/facelib/parsing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72656e4b5f61df8cd0838588b0c6488fcc886e16
--- /dev/null
+++ b/repositories/CodeFormer/facelib/parsing/__init__.py
@@ -0,0 +1,23 @@
+import torch
+
+from facelib.utils import load_file_from_url
+from .bisenet import BiSeNet
+from .parsenet import ParseNet
+
+
+def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
+ if model_name == 'bisenet':
+ model = BiSeNet(num_class=19)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
+ elif model_name == 'parsenet':
+ model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
diff --git a/repositories/CodeFormer/facelib/parsing/bisenet.py b/repositories/CodeFormer/facelib/parsing/bisenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3898cab76ae5876459cd4899c54cafa14234971d
--- /dev/null
+++ b/repositories/CodeFormer/facelib/parsing/bisenet.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .resnet import ResNet18
+
+
+class ConvBNReLU(nn.Module):
+
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
+ super(ConvBNReLU, self).__init__()
+ self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm2d(out_chan)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(self.bn(x))
+ return x
+
+
+class BiSeNetOutput(nn.Module):
+
+ def __init__(self, in_chan, mid_chan, num_class):
+ super(BiSeNetOutput, self).__init__()
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+ self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ feat = self.conv(x)
+ out = self.conv_out(feat)
+ return out, feat
+
+
+class AttentionRefinementModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(AttentionRefinementModule, self).__init__()
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+ self.bn_atten = nn.BatchNorm2d(out_chan)
+ self.sigmoid_atten = nn.Sigmoid()
+
+ def forward(self, x):
+ feat = self.conv(x)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv_atten(atten)
+ atten = self.bn_atten(atten)
+ atten = self.sigmoid_atten(atten)
+ out = torch.mul(feat, atten)
+ return out
+
+
+class ContextPath(nn.Module):
+
+ def __init__(self):
+ super(ContextPath, self).__init__()
+ self.resnet = ResNet18()
+ self.arm16 = AttentionRefinementModule(256, 128)
+ self.arm32 = AttentionRefinementModule(512, 128)
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+ def forward(self, x):
+ feat8, feat16, feat32 = self.resnet(x)
+ h8, w8 = feat8.size()[2:]
+ h16, w16 = feat16.size()[2:]
+ h32, w32 = feat32.size()[2:]
+
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
+ avg = self.conv_avg(avg)
+ avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
+
+ feat32_arm = self.arm32(feat32)
+ feat32_sum = feat32_arm + avg_up
+ feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
+ feat32_up = self.conv_head32(feat32_up)
+
+ feat16_arm = self.arm16(feat16)
+ feat16_sum = feat16_arm + feat32_up
+ feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
+ feat16_up = self.conv_head16(feat16_up)
+
+ return feat8, feat16_up, feat32_up # x8, x8, x16
+
+
+class FeatureFusionModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(FeatureFusionModule, self).__init__()
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+ self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, fsp, fcp):
+ fcat = torch.cat([fsp, fcp], dim=1)
+ feat = self.convblk(fcat)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv1(atten)
+ atten = self.relu(atten)
+ atten = self.conv2(atten)
+ atten = self.sigmoid(atten)
+ feat_atten = torch.mul(feat, atten)
+ feat_out = feat_atten + feat
+ return feat_out
+
+
+class BiSeNet(nn.Module):
+
+ def __init__(self, num_class):
+ super(BiSeNet, self).__init__()
+ self.cp = ContextPath()
+ self.ffm = FeatureFusionModule(256, 256)
+ self.conv_out = BiSeNetOutput(256, 256, num_class)
+ self.conv_out16 = BiSeNetOutput(128, 64, num_class)
+ self.conv_out32 = BiSeNetOutput(128, 64, num_class)
+
+ def forward(self, x, return_feat=False):
+ h, w = x.size()[2:]
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
+ feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+ out, feat = self.conv_out(feat_fuse)
+ out16, feat16 = self.conv_out16(feat_cp8)
+ out32, feat32 = self.conv_out32(feat_cp16)
+
+ out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
+ out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
+ out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
+
+ if return_feat:
+ feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
+ feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
+ feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
+ return out, out16, out32, feat, feat16, feat32
+ else:
+ return out, out16, out32
diff --git a/repositories/CodeFormer/facelib/parsing/parsenet.py b/repositories/CodeFormer/facelib/parsing/parsenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e178ebe43a1ef666aaea0bc0faf629485c22a24f
--- /dev/null
+++ b/repositories/CodeFormer/facelib/parsing/parsenet.py
@@ -0,0 +1,194 @@
+"""Modified from https://github.com/chaofengc/PSFRGAN
+"""
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class NormLayer(nn.Module):
+ """Normalization Layers.
+
+ Args:
+ channels: input channels, for batch norm and instance norm.
+ input_size: input shape without batch size, for layer norm.
+ """
+
+ def __init__(self, channels, normalize_shape=None, norm_type='bn'):
+ super(NormLayer, self).__init__()
+ norm_type = norm_type.lower()
+ self.norm_type = norm_type
+ if norm_type == 'bn':
+ self.norm = nn.BatchNorm2d(channels, affine=True)
+ elif norm_type == 'in':
+ self.norm = nn.InstanceNorm2d(channels, affine=False)
+ elif norm_type == 'gn':
+ self.norm = nn.GroupNorm(32, channels, affine=True)
+ elif norm_type == 'pixel':
+ self.norm = lambda x: F.normalize(x, p=2, dim=1)
+ elif norm_type == 'layer':
+ self.norm = nn.LayerNorm(normalize_shape)
+ elif norm_type == 'none':
+ self.norm = lambda x: x * 1.0
+ else:
+ assert 1 == 0, f'Norm type {norm_type} not supported.'
+
+ def forward(self, x, ref=None):
+ if self.norm_type == 'spade':
+ return self.norm(x, ref)
+ else:
+ return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+ """Relu Layer.
+
+ Args:
+ relu type: type of relu layer, candidates are
+ - ReLU
+ - LeakyReLU: default relu slope 0.2
+ - PRelu
+ - SELU
+ - none: direct pass
+ """
+
+ def __init__(self, channels, relu_type='relu'):
+ super(ReluLayer, self).__init__()
+ relu_type = relu_type.lower()
+ if relu_type == 'relu':
+ self.func = nn.ReLU(True)
+ elif relu_type == 'leakyrelu':
+ self.func = nn.LeakyReLU(0.2, inplace=True)
+ elif relu_type == 'prelu':
+ self.func = nn.PReLU(channels)
+ elif relu_type == 'selu':
+ self.func = nn.SELU(True)
+ elif relu_type == 'none':
+ self.func = lambda x: x * 1.0
+ else:
+ assert 1 == 0, f'Relu type {relu_type} not supported.'
+
+ def forward(self, x):
+ return self.func(x)
+
+
+class ConvLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ scale='none',
+ norm_type='none',
+ relu_type='none',
+ use_pad=True,
+ bias=True):
+ super(ConvLayer, self).__init__()
+ self.use_pad = use_pad
+ self.norm_type = norm_type
+ if norm_type in ['bn']:
+ bias = False
+
+ stride = 2 if scale == 'down' else 1
+
+ self.scale_func = lambda x: x
+ if scale == 'up':
+ self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+
+ self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+ self.relu = ReluLayer(out_channels, relu_type)
+ self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+ def forward(self, x):
+ out = self.scale_func(x)
+ if self.use_pad:
+ out = self.reflection_pad(out)
+ out = self.conv2d(out)
+ out = self.norm(out)
+ out = self.relu(out)
+ return out
+
+
+class ResidualBlock(nn.Module):
+ """
+ Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+ """
+
+ def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
+ super(ResidualBlock, self).__init__()
+
+ if scale == 'none' and c_in == c_out:
+ self.shortcut_func = lambda x: x
+ else:
+ self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+ scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
+ scale_conf = scale_config_dict[scale]
+
+ self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+ self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
+
+ def forward(self, x):
+ identity = self.shortcut_func(x)
+
+ res = self.conv1(x)
+ res = self.conv2(res)
+ return identity + res
+
+
+class ParseNet(nn.Module):
+
+ def __init__(self,
+ in_size=128,
+ out_size=128,
+ min_feat_size=32,
+ base_ch=64,
+ parsing_ch=19,
+ res_depth=10,
+ relu_type='LeakyReLU',
+ norm_type='bn',
+ ch_range=[32, 256]):
+ super().__init__()
+ self.res_depth = res_depth
+ act_args = {'norm_type': norm_type, 'relu_type': relu_type}
+ min_ch, max_ch = ch_range
+
+ ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
+ min_feat_size = min(in_size, min_feat_size)
+
+ down_steps = int(np.log2(in_size // min_feat_size))
+ up_steps = int(np.log2(out_size // min_feat_size))
+
+ # =============== define encoder-body-decoder ====================
+ self.encoder = []
+ self.encoder.append(ConvLayer(3, base_ch, 3, 1))
+ head_ch = base_ch
+ for i in range(down_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
+ self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
+ head_ch = head_ch * 2
+
+ self.body = []
+ for i in range(res_depth):
+ self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
+
+ self.decoder = []
+ for i in range(up_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+ self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
+ head_ch = head_ch // 2
+
+ self.encoder = nn.Sequential(*self.encoder)
+ self.body = nn.Sequential(*self.body)
+ self.decoder = nn.Sequential(*self.decoder)
+ self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
+ self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
+
+ def forward(self, x):
+ feat = self.encoder(x)
+ x = feat + self.body(feat)
+ x = self.decoder(x)
+ out_img = self.out_img_conv(x)
+ out_mask = self.out_mask_conv(x)
+ return out_mask, out_img
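+
+# Architecture note (added for clarity): ParseNet is an encoder-body-decoder network with a long
+# skip connection (x = feat + self.body(feat)); the shared decoder features feed both a
+# parsing-mask head (parsing_ch channels) and a 3-channel image-reconstruction head.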
diff --git a/repositories/CodeFormer/facelib/parsing/resnet.py b/repositories/CodeFormer/facelib/parsing/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec8e82cf64469fb51be21ad5130217052addbda
--- /dev/null
+++ b/repositories/CodeFormer/facelib/parsing/resnet.py
@@ -0,0 +1,69 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(self, in_chan, out_chan, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
+ self.bn1 = nn.BatchNorm2d(out_chan)
+ self.conv2 = conv3x3(out_chan, out_chan)
+ self.bn2 = nn.BatchNorm2d(out_chan)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ if in_chan != out_chan or stride != 1:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_chan),
+ )
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = F.relu(self.bn1(residual))
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = shortcut + residual
+ out = self.relu(out)
+ return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+ for i in range(bnum - 1):
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
+ return nn.Sequential(*layers)
+
+
+class ResNet18(nn.Module):
+
+ def __init__(self):
+ super(ResNet18, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(self.bn1(x))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ feat8 = self.layer2(x) # 1/8
+ feat16 = self.layer3(feat8) # 1/16
+ feat32 = self.layer4(feat16) # 1/32
+ return feat8, feat16, feat32
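+
+# Note (added for clarity): the three returned feature maps are at 1/8, 1/16 and 1/32 of the
+# input resolution; ContextPath in bisenet.py consumes them as its backbone features.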
diff --git a/repositories/CodeFormer/facelib/utils/__init__.py b/repositories/CodeFormer/facelib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f03b1c2bafcd7759cb7e8722a0c6715f201a46dc
--- /dev/null
+++ b/repositories/CodeFormer/facelib/utils/__init__.py
@@ -0,0 +1,7 @@
+from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
+from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
+
+__all__ = [
+ 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url',
+ 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
+]
diff --git a/repositories/CodeFormer/facelib/utils/face_restoration_helper.py b/repositories/CodeFormer/facelib/utils/face_restoration_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7644ccd3d9978aea7997a76f7c6fdb0ccde8b1
--- /dev/null
+++ b/repositories/CodeFormer/facelib/utils/face_restoration_helper.py
@@ -0,0 +1,455 @@
+import cv2
+import numpy as np
+import os
+import torch
+from torchvision.transforms.functional import normalize
+
+from facelib.detection import init_detection_model
+from facelib.parsing import init_parsing_model
+from facelib.utils.misc import img2tensor, imwrite
+
+
+def get_largest_face(det_faces, h, w):
+
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = upscale_factor
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+ assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+
+ if self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ else:
+ self.device = device
+
+ # init face detection model
+ self.face_det = init_detection_model(det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+
+ if min(self.input_img.shape[:2]) < 512:
+ f = 512.0 / min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def get_face_landmarks_5(self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None):
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_det.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1)
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype('float32')
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+ """Align and warp faces with face template.
+ """
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ cropped_face = cv2.warpAffine(
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+
+ def add_restored_face(self, face):
+ self.restored_faces.append(face)
+
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ assert len(self.restored_faces) == len(
+ self.inverse_affine_matrices), ('lengths of restored_faces and affine_matrices differ.')
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400/np.sqrt(total_face_area))
+ mask_border[border:h-border, border:w-border,:] = 0
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+ fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+ inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+ # blend the pasted face into the upsampled background with the soft mask
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+ if np.max(upsample_img) > 256: # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:,:,0] = 0
+ img_color[:,:,1] = 255
+ img_color[:,:,2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f'{path}.{self.save_ext}'
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
\ No newline at end of file
diff --git a/repositories/CodeFormer/facelib/utils/face_utils.py b/repositories/CodeFormer/facelib/utils/face_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1474a2a4419b6b62fab8a919ef805b802556464
--- /dev/null
+++ b/repositories/CodeFormer/facelib/utils/face_utils.py
@@ -0,0 +1,248 @@
+import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+ left, top, right, bot = bbox
+ width = right - left
+ height = bot - top
+
+ if preserve_aspect:
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+ else:
+ width_increase = height_increase = increase_area
+ left = int(left - width_increase * width)
+ top = int(top - height_increase * height)
+ right = int(right + width_increase * width)
+ bot = int(bot + height_increase * height)
+ return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+ left = max(bboxes[0], 0)
+ top = max(bboxes[1], 0)
+ right = min(bboxes[2], w)
+ bottom = min(bboxes[3], h)
+ return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(img,
+ landmarks,
+ output_size,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=False,
+ shrink_ratio=(1, 1)):
+ """Align and crop face with landmarks.
+
+ The output_size and transform_size are based on width. The height is
+ adjusted according to shrink_ratio[0] / shrink_ratio[1] (height / width).
+
+ Modified from:
+ https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+ Args:
+ img (Numpy array): Input image.
+ landmarks (Numpy array): 5 or 68 or 98 landmarks.
+ output_size (int): Output face size.
+ transform_size (int): Transform size. Usually four times
+ output_size.
+ enable_padding (bool): Default: True.
+ shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
+ face for height and width (crop a larger area). Default: (1, 1).
+
+ Returns:
+ (Numpy array): Cropped face.
+ """
+ lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
+
+ if isinstance(shrink_ratio, (float, int)):
+ shrink_ratio = (shrink_ratio, shrink_ratio)
+ if transform_size is None:
+ transform_size = output_size * 4
+
+ # Parse landmarks
+ lm = np.array(landmarks)
+ if lm.shape[0] == 5 and lm_type == 'retinaface_5':
+ eye_left = lm[0]
+ eye_right = lm[1]
+ mouth_avg = (lm[3] + lm[4]) * 0.5
+ elif lm.shape[0] == 5 and lm_type == 'dlib_5':
+ lm_eye_left = lm[2:4]
+ lm_eye_right = lm[0:2]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = lm[4]
+ elif lm.shape[0] == 68:
+ lm_eye_left = lm[36:42]
+ lm_eye_right = lm[42:48]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[48] + lm[54]) * 0.5
+ elif lm.shape[0] == 98:
+ lm_eye_left = lm[60:68]
+ lm_eye_right = lm[68:76]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[76] + lm[82]) * 0.5
+
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1 # TODO: you can edit it to get larger rect
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ x *= shrink_ratio[1] # width
+ y *= shrink_ratio[0] # height
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+
+ quad_ori = np.copy(quad)
+ # Shrink, for large face
+ # TODO: do we really need shrink
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ h, w = img.shape[0:2]
+ rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+ img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop
+ h, w = img.shape[0:2]
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
+ if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
+ img = img[crop[1]:crop[3], crop[0]:crop[2], :]
+ quad -= crop[0:2]
+
+ # Pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ h, w = img.shape[0:2]
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w = img.shape[0:2]
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * 0.02)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
+
+ img = img.astype('float32')
+ img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = np.clip(img, 0, 255) # float32, [0, 255]
+ quad += pad[:2]
+
+ # Transform use cv2
+ h_ratio = shrink_ratio[0] / shrink_ratio[1]
+ dst_h, dst_w = int(transform_size * h_ratio), transform_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
+ cropped_face = cv2.warpAffine(
+ img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
+
+ if output_size < transform_size:
+ cropped_face = cv2.resize(
+ cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
+
+ if return_inverse_affine:
+ dst_h, dst_w = int(output_size * h_ratio), output_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+        affine_matrix = cv2.estimateAffinePartial2D(quad_ori, template, method=cv2.LMEDS)[0]
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ else:
+ inverse_affine = None
+ return cropped_face, inverse_affine
+
+
+def paste_face_back(img, face, inverse_affine):
+ h, w = img.shape[0:2]
+ face_h, face_w = face.shape[0:2]
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
+ mask = np.ones((face_h, face_w, 3), dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) // 3
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+ # float32, [0, 255]
+ return img
+
+
+if __name__ == '__main__':
+ import os
+
+ from facelib.detection import init_detection_model
+ from facelib.utils.face_restoration_helper import get_largest_face
+
+ img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
+    img_name = os.path.splitext(os.path.basename(img_path))[0]
+
+ # initialize model
+ det_net = init_detection_model('retinaface_resnet50', half=False)
+ img_ori = cv2.imread(img_path)
+ h, w = img_ori.shape[0:2]
+ # if larger than 800, scale it
+ scale = max(h / 800, w / 800)
+    if scale > 1:
+        img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+    else:
+        img = img_ori
+
+ with torch.no_grad():
+ bboxes = det_net.detect_faces(img, 0.97)
+ if scale > 1:
+            bboxes *= scale  # note: this also scales the confidence score, which becomes incorrect
+ bboxes = get_largest_face(bboxes, h, w)[0]
+
+ landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+ cropped_face, inverse_affine = align_crop_face_landmarks(
+ img_ori,
+ landmarks,
+ output_size=512,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=True,
+ shrink_ratio=(1, 1))
+
+    cv2.imwrite(f'tmp/{img_name}_cropped_face.png', cropped_face)
+ img = paste_face_back(img_ori, cropped_face, inverse_affine)
+ cv2.imwrite(f'tmp/{img_name}_back.png', img)
diff --git a/repositories/CodeFormer/facelib/utils/misc.py b/repositories/CodeFormer/facelib/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0918283c297a927fc0216670bbe78079087c6312
--- /dev/null
+++ b/repositories/CodeFormer/facelib/utils/misc.py
@@ -0,0 +1,141 @@
+import cv2
+import os
+import os.path as osp
+import torch
+from torch.hub import download_url_to_file, get_dir
+from urllib.parse import urlparse
+# from basicsr.utils.download_util import download_file_from_google_drive
+import gdown
+
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = 'https://drive.google.com/uc?id='+file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+                gdown.download(file_url, save_path, quiet=False)
+                # download_file_from_google_drive(file_id, save_path)
+            elif user_response.lower() == 'n':
+                print(f'Skipping {file_name}')
+            else:
+                raise ValueError('Invalid input. Only Y/N is accepted.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
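+
+# Illustrative usage of imwrite/img2tensor (a sketch only; the file paths are hypothetical):
+#   img = cv2.imread('face.png')                                  # HWC, BGR, uint8
+#   tensor = img2tensor(img / 255., bgr2rgb=True, float32=True)   # CHW, RGB, float32 in [0, 1]
+#   imwrite(img, 'results/face_copy.png')                         # creates results/ if needed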
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ """
+ if model_dir is None:
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+    Returns:
+        A generator over the matched files (relative paths unless full_path is True).
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
diff --git a/repositories/CodeFormer/inference_codeformer.py b/repositories/CodeFormer/inference_codeformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdfe8b301cc7c20c2fb653618e379d243603a108
--- /dev/null
+++ b/repositories/CodeFormer/inference_codeformer.py
@@ -0,0 +1,189 @@
+# Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
+import os
+import cv2
+import argparse
+import glob
+import torch
+from torchvision.transforms.functional import normalize
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+pretrain_model_url = {
+ 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+}
+
+def set_realesrgan():
+ if not torch.cuda.is_available(): # CPU
+ import warnings
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
+ 'If you really want to use it, please modify the corresponding codes.',
+ category=RuntimeWarning)
+ bg_upsampler = None
+ else:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils.realesrgan_utils import RealESRGANer
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ bg_upsampler = RealESRGANer(
+ scale=2,
+ model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+ model=model,
+ tile=args.bg_tile,
+ tile_pad=40,
+ pre_pad=0,
+ half=True) # need to set False in CPU mode
+ return bg_upsampler
+
+if __name__ == '__main__':
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
+ parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
+ parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
+ parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
+ # large det_model: 'YOLOv5l', 'retinaface_resnet50'
+ # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
+ parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
+ parser.add_argument('--draw_box', action='store_true')
+ parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
+ parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
+ parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
+
+ args = parser.parse_args()
+
+ # ------------------------ input & output ------------------------
+    if args.test_path.endswith('/'):  # strip a trailing slash from the path
+ args.test_path = args.test_path[:-1]
+
+ w = args.w
+ result_root = f'results/{os.path.basename(args.test_path)}_{w}'
+
+ # ------------------ set up background upsampler ------------------
+ if args.bg_upsampler == 'realesrgan':
+ bg_upsampler = set_realesrgan()
+ else:
+ bg_upsampler = None
+
+ # ------------------ set up face upsampler ------------------
+ if args.face_upsample:
+ if bg_upsampler is not None:
+ face_upsampler = bg_upsampler
+ else:
+ face_upsampler = set_realesrgan()
+ else:
+ face_upsampler = None
+
+ # ------------------ set up CodeFormer restorer -------------------
+ net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
+ connect_list=['32', '64', '128', '256']).to(device)
+
+ # ckpt_path = 'weights/CodeFormer/codeformer.pth'
+ ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
+ model_dir='weights/CodeFormer', progress=True, file_name=None)
+ checkpoint = torch.load(ckpt_path)['params_ema']
+ net.load_state_dict(checkpoint)
+ net.eval()
+
+ # ------------------ set up FaceRestoreHelper -------------------
+ # large det_model: 'YOLOv5l', 'retinaface_resnet50'
+ # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
+ if not args.has_aligned:
+ print(f'Face detection model: {args.detection_model}')
+ if bg_upsampler is not None:
+ print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
+ else:
+ print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
+
+ face_helper = FaceRestoreHelper(
+ args.upscale,
+ face_size=512,
+ crop_ratio=(1, 1),
+        det_model=args.detection_model,
+ save_ext='png',
+ use_parse=True,
+ device=device)
+
+    # -------------------- start processing ---------------------
+ # scan all the jpg and png images
+ for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
+ # clean all the intermediate results to process the next image
+ face_helper.clean_all()
+
+ img_name = os.path.basename(img_path)
+ print(f'Processing: {img_name}')
+ basename, ext = os.path.splitext(img_name)
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+
+ if args.has_aligned:
+ # the input faces are already cropped and aligned
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_helper.cropped_faces = [img]
+ else:
+ face_helper.read_image(img)
+ # get face landmarks for each face
+ num_det_faces = face_helper.get_face_landmarks_5(
+ only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
+            print(f'\tdetected {num_det_faces} faces')
+ # align and warp each face
+ face_helper.align_warp_face()
+
+ # face restoration for each cropped face
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
+ # prepare data
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+ try:
+ with torch.no_grad():
+ output = net(cropped_face_t, w=w, adain=True)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f'\tFailed inference for CodeFormer: {error}')
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ # paste_back
+ if not args.has_aligned:
+ # upsample the background
+ if bg_upsampler is not None:
+ # Now only support RealESRGAN for upsampling background
+ bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
+ else:
+ bg_img = None
+ face_helper.get_inverse_affine(None)
+ # paste each restored face to the input image
+ if args.face_upsample and face_upsampler is not None:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
+ else:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
+
+ # save faces
+ for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
+ # save cropped face
+ if not args.has_aligned:
+ save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
+ imwrite(cropped_face, save_crop_path)
+ # save restored face
+ if args.has_aligned:
+ save_face_name = f'{basename}.png'
+ else:
+ save_face_name = f'{basename}_{idx:02d}.png'
+ save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
+ imwrite(restored_face, save_restore_path)
+
+ # save restored img
+ if not args.has_aligned and restored_img is not None:
+ save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
+ imwrite(restored_img, save_restore_path)
+
+ print(f'\nAll results are saved in {result_root}')
diff --git a/repositories/CodeFormer/predict.py b/repositories/CodeFormer/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..01ecc6799dee747257167516bfcd66b98efec925
--- /dev/null
+++ b/repositories/CodeFormer/predict.py
@@ -0,0 +1,188 @@
+"""
+download checkpoints to ./weights beforehand
+python scripts/download_pretrained_models.py facelib
+python scripts/download_pretrained_models.py CodeFormer
+wget 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth'
+"""
+
+import tempfile
+import cv2
+import torch
+from torchvision.transforms.functional import normalize
+from cog import BasePredictor, Input, Path
+
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.realesrgan_utils import RealESRGANer
+from basicsr.utils.registry import ARCH_REGISTRY
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+
+
+class Predictor(BasePredictor):
+ def setup(self):
+ """Load the model into memory to make running multiple predictions efficient"""
+ self.device = "cuda:0"
+ self.bg_upsampler = set_realesrgan()
+ self.net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+ ).to(self.device)
+ ckpt_path = "weights/CodeFormer/codeformer.pth"
+ checkpoint = torch.load(ckpt_path)[
+ "params_ema"
+        ]  # update file permissions if the checkpoint cannot be loaded
+ self.net.load_state_dict(checkpoint)
+ self.net.eval()
+
+ def predict(
+ self,
+ image: Path = Input(description="Input image"),
+ codeformer_fidelity: float = Input(
+ default=0.5,
+ ge=0,
+ le=1,
+ description="Balance the quality (lower number) and fidelity (higher number).",
+ ),
+ background_enhance: bool = Input(
+ description="Enhance background image with Real-ESRGAN", default=True
+ ),
+ face_upsample: bool = Input(
+ description="Upsample restored faces for high-resolution AI-created images",
+ default=True,
+ ),
+ upscale: int = Input(
+ description="The final upsampling scale of the image",
+ default=2,
+ ),
+ ) -> Path:
+ """Run a single prediction on the model"""
+
+ # take the default setting for the demo
+ has_aligned = False
+ only_center_face = False
+ draw_box = False
+ detection_model = "retinaface_resnet50"
+
+ self.face_helper = FaceRestoreHelper(
+ upscale,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model=detection_model,
+ save_ext="png",
+ use_parse=True,
+ device=self.device,
+ )
+
+ bg_upsampler = self.bg_upsampler if background_enhance else None
+ face_upsampler = self.bg_upsampler if face_upsample else None
+
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+
+ if has_aligned:
+ # the input faces are already cropped and aligned
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ self.face_helper.cropped_faces = [img]
+ else:
+ self.face_helper.read_image(img)
+ # get face landmarks for each face
+ num_det_faces = self.face_helper.get_face_landmarks_5(
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+ )
+ print(f"\tdetect {num_det_faces} faces")
+ # align and warp each face
+ self.face_helper.align_warp_face()
+
+ # face restoration for each cropped face
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ # prepare data
+ cropped_face_t = img2tensor(
+ cropped_face / 255.0, bgr2rgb=True, float32=True
+ )
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
+
+ try:
+ with torch.no_grad():
+ output = self.net(
+ cropped_face_t, w=codeformer_fidelity, adain=True
+ )[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f"\tFailed inference for CodeFormer: {error}")
+ restored_face = tensor2img(
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+ )
+
+ restored_face = restored_face.astype("uint8")
+ self.face_helper.add_restored_face(restored_face)
+
+ # paste_back
+ if not has_aligned:
+ # upsample the background
+ if bg_upsampler is not None:
+ # Now only support RealESRGAN for upsampling background
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+ else:
+ bg_img = None
+ self.face_helper.get_inverse_affine(None)
+ # paste each restored face to the input image
+ if face_upsample and face_upsampler is not None:
+ restored_img = self.face_helper.paste_faces_to_input_image(
+ upsample_img=bg_img,
+ draw_box=draw_box,
+ face_upsampler=face_upsampler,
+ )
+ else:
+ restored_img = self.face_helper.paste_faces_to_input_image(
+ upsample_img=bg_img, draw_box=draw_box
+ )
+
+ # save restored img
+ out_path = Path(tempfile.mkdtemp()) / "output.png"
+
+ if not has_aligned and restored_img is not None:
+ imwrite(restored_img, str(out_path))
+
+ return out_path
+
+
+def imread(img_path):
+ img = cv2.imread(img_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
+
+
+def set_realesrgan():
+ if not torch.cuda.is_available(): # CPU
+ import warnings
+
+ warnings.warn(
+ "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
+ "If you really want to use it, please modify the corresponding codes.",
+ category=RuntimeWarning,
+ )
+ bg_upsampler = None
+ else:
+ model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=2,
+ )
+ bg_upsampler = RealESRGANer(
+ scale=2,
+ model_path="./weights/RealESRGAN_x2plus.pth",
+ model=model,
+ tile=400,
+ tile_pad=40,
+ pre_pad=0,
+ half=True,
+ )
+ return bg_upsampler
diff --git a/repositories/CodeFormer/requirements.txt b/repositories/CodeFormer/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f97dfde85ebe83708fc1f6f7234a0ef69f18bde5
--- /dev/null
+++ b/repositories/CodeFormer/requirements.txt
@@ -0,0 +1,20 @@
+addict
+future
+lmdb
+numpy
+opencv-python
+Pillow
+pyyaml
+requests
+scikit-image
+scipy
+tb-nightly
+torch>=1.7.1
+torchvision
+tqdm
+yapf
+lpips
+gdown # supports downloading large files from Google Drive
+# cmake
+# dlib
+# conda install -c conda-forge dlib
\ No newline at end of file
diff --git a/repositories/CodeFormer/scripts/crop_align_face.py b/repositories/CodeFormer/scripts/crop_align_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e66266ac0e5f818fa18b6409993151086bbc8b
--- /dev/null
+++ b/repositories/CodeFormer/scripts/crop_align_face.py
@@ -0,0 +1,192 @@
+"""
+brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+author: lzhbrian (https://lzhbrian.me)
+link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
+date: 2020.1.5
+note: code is heavily borrowed from
+ https://github.com/NVlabs/ffhq-dataset
+ http://dlib.net/face_landmark_detection.py.html
+requirements:
+ conda install Pillow numpy scipy
+ conda install -c conda-forge dlib
+ # download face landmark model from:
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+"""
+
+import cv2
+import dlib
+import glob
+import numpy as np
+import os
+import PIL
+import PIL.Image
+import scipy
+import scipy.ndimage
+import sys
+import argparse
+
+# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
+
+
+def get_landmark(filepath, only_keep_largest=True):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ detector = dlib.get_frontal_face_detector()
+
+ img = dlib.load_rgb_image(filepath)
+ dets = detector(img, 1)
+
+ # Shangchen modified
+ print("Number of faces detected: {}".format(len(dets)))
+ if only_keep_largest:
+        print('Keeping only the largest detected face.')
+ face_areas = []
+ for k, d in enumerate(dets):
+ face_area = (d.right() - d.left()) * (d.bottom() - d.top())
+ face_areas.append(face_area)
+
+ largest_idx = face_areas.index(max(face_areas))
+ d = dets[largest_idx]
+ shape = predictor(img, d)
+ print("Part 0: {}, Part 1: {} ...".format(
+ shape.part(0), shape.part(1)))
+ else:
+ for k, d in enumerate(dets):
+ print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
+ k, d.left(), d.top(), d.right(), d.bottom()))
+ # Get the landmarks/parts for the face in box d.
+ shape = predictor(img, d)
+ print("Part 0: {}, Part 1: {} ...".format(
+ shape.part(0), shape.part(1)))
+
+ t = list(shape.parts())
+ a = []
+ for tt in t:
+ a.append([tt.x, tt.y])
+ lm = np.array(a)
+ # lm is a shape=(68,2) np.array
+ return lm
+
+def align_face(filepath, out_path):
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+ try:
+ lm = get_landmark(filepath)
+    except Exception:
+        print('No landmarks detected ...')
+ return
+
+ lm_chin = lm[0:17] # left-right
+ lm_eyebrow_left = lm[17:22] # left-right
+ lm_eyebrow_right = lm[22:27] # left-right
+ lm_nose = lm[27:31] # top-down
+ lm_nostrils = lm[31:36] # top-down
+ lm_eye_left = lm[36:42] # left-clockwise
+ lm_eye_right = lm[42:48] # left-clockwise
+ lm_mouth_outer = lm[48:60] # left-clockwise
+ lm_mouth_inner = lm[60:68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # read image
+ img = PIL.Image.open(filepath)
+
+ output_size = 512
+ transform_size = 4096
+ enable_padding = False
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)),
+ int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
+ min(crop[2] + border,
+ img.size[0]), min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border,
+ 0), max(-pad[1] + border,
+ 0), max(pad[2] - img.size[0] + border,
+ 0), max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(
+ np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
+ 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(
+ 1.0 -
+ np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]), 1.0 -
+ np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
+ img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = PIL.Image.fromarray(
+ np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
+ (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+
+ if output_size < transform_size:
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+ # Save aligned image.
+    print('saving:', out_path)
+ img.save(out_path)
+
+ return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--in_dir', type=str, default='./inputs/whole_imgs')
+ parser.add_argument('--out_dir', type=str, default='./inputs/cropped_faces')
+ args = parser.parse_args()
+
+    os.makedirs(args.out_dir, exist_ok=True)
+    img_list = sorted(glob.glob(f'{args.in_dir}/*.png'))
+
+ for in_path in img_list:
+ out_path = os.path.join(args.out_dir, in_path.split("/")[-1])
+ out_path = out_path.replace('.jpg', '.png')
+ size_ = align_face(in_path, out_path)
\ No newline at end of file
diff --git a/repositories/CodeFormer/scripts/download_pretrained_models.py b/repositories/CodeFormer/scripts/download_pretrained_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa6e8ca14ea91c89a318e85d9f182eb7d1bf025
--- /dev/null
+++ b/repositories/CodeFormer/scripts/download_pretrained_models.py
@@ -0,0 +1,40 @@
+import argparse
+import os
+from os import path as osp
+
+from basicsr.utils.download_util import load_file_from_url
+
+
+def download_pretrained_models(method, file_urls):
+ save_path_root = f'./weights/{method}'
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_url in file_urls.items():
+ save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'method',
+ type=str,
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
+ args = parser.parse_args()
+
+ file_urls = {
+ 'CodeFormer': {
+ 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+ },
+ 'facelib': {
+ # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
+ 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
+ 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+ }
+ }
+
+ if args.method == 'all':
+ for method in file_urls.keys():
+ download_pretrained_models(method, file_urls[method])
+ else:
+ download_pretrained_models(args.method, file_urls[args.method])
\ No newline at end of file
diff --git a/repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py b/repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df5be6fc260394ee9bbd0a7ae377e2ca657fe83
--- /dev/null
+++ b/repositories/CodeFormer/scripts/download_pretrained_models_from_gdrive.py
@@ -0,0 +1,60 @@
+import argparse
+import os
+from os import path as osp
+
+# from basicsr.utils.download_util import download_file_from_google_drive
+import gdown
+
+
+def download_pretrained_models(method, file_ids):
+ save_path_root = f'./weights/{method}'
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = 'https://drive.google.com/uc?id='+file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+                gdown.download(file_url, save_path, quiet=False)
+                # download_file_from_google_drive(file_id, save_path)
+            elif user_response.lower() == 'n':
+                print(f'Skipping {file_name}')
+            else:
+                raise ValueError('Invalid input. Only Y/N is accepted.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'method',
+ type=str,
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
+ args = parser.parse_args()
+
+ # file name: file id
+ # 'dlib': {
+ # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
+ # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
+ # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
+ # }
+ file_ids = {
+ 'CodeFormer': {
+ 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
+ },
+ 'facelib': {
+ 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
+ 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
+ }
+ }
+
+ if args.method == 'all':
+ for method in file_ids.keys():
+ download_pretrained_models(method, file_ids[method])
+ else:
+ download_pretrained_models(args.method, file_ids[args.method])
\ No newline at end of file
diff --git a/repositories/CodeFormer/weights/CodeFormer/.gitkeep b/repositories/CodeFormer/weights/CodeFormer/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/CodeFormer/weights/README.md b/repositories/CodeFormer/weights/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..67ad334bd672eeb9f82813cd54e8885331bbb2f2
--- /dev/null
+++ b/repositories/CodeFormer/weights/README.md
@@ -0,0 +1,3 @@
+# Weights
+
+Put the downloaded pre-trained models to this folder.
\ No newline at end of file
diff --git a/repositories/CodeFormer/weights/facelib/.gitkeep b/repositories/CodeFormer/weights/facelib/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/k-diffusion/.github/workflows/python-publish.yml b/repositories/k-diffusion/.github/workflows/python-publish.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9638b6b4575cb6544bf45429ad86899262a50762
--- /dev/null
+++ b/repositories/k-diffusion/.github/workflows/python-publish.yml
@@ -0,0 +1,37 @@
+name: Release
+
+on:
+ push:
+ branches:
+ - master
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions-ecosystem/action-regex-match@v2
+ id: regex-match
+ with:
+ text: ${{ github.event.head_commit.message }}
+ regex: '^Release ([^ ]+)'
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.8'
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install setuptools wheel twine
+ - name: Release
+ if: ${{ steps.regex-match.outputs.match != '' }}
+ uses: softprops/action-gh-release@v1
+ with:
+ tag_name: v${{ steps.regex-match.outputs.group1 }}
+ - name: Build and publish
+ if: ${{ steps.regex-match.outputs.match != '' }}
+ env:
+ TWINE_USERNAME: __token__
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: |
+ python setup.py sdist bdist_wheel
+ twine upload dist/*
diff --git a/repositories/k-diffusion/.gitignore b/repositories/k-diffusion/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e647ac957b0cfb9b4d10af4829e84f0f1ca0b2f0
--- /dev/null
+++ b/repositories/k-diffusion/.gitignore
@@ -0,0 +1,10 @@
+venv*
+__pycache__
+.ipynb_checkpoints
+*.pth
+*.egg-info
+data
+*_demo_*.png
+wandb/*
+*.csv
+.env
\ No newline at end of file
diff --git a/repositories/k-diffusion/LICENSE b/repositories/k-diffusion/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..37a42365ab6670d39334fc9b650975c819fee3d1
--- /dev/null
+++ b/repositories/k-diffusion/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2022 Katherine Crowson
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/repositories/k-diffusion/README.md b/repositories/k-diffusion/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f7c92f9a947da69dedecd8cc1149b3713ab2e80
--- /dev/null
+++ b/repositories/k-diffusion/README.md
@@ -0,0 +1,61 @@
+# k-diffusion
+
+An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch. The patching method in [Improving Diffusion Model Efficiency Through Patching](https://arxiv.org/abs/2207.04316) is implemented as well.
+
+## Installation
+
+`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e .` from the repository root.
+
+## Training:
+
+To train models:
+
+```sh
+$ ./train.py --config CONFIG_FILE --name RUN_NAME
+```
+
+For instance, to train a model on MNIST:
+
+```sh
+$ ./train.py --config configs/config_mnist.json --name RUN_NAME
+```
+
+The configuration file allows you to specify the dataset type. Currently supported types are `"imagefolder"` (finds all images in that folder and its subfolders, recursively), `"cifar10"` (CIFAR-10), `"mnist"` (MNIST), and `"huggingface"` ([Hugging Face Datasets](https://huggingface.co/docs/datasets/index)).
+
+Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running:
+
+```sh
+$ accelerate config
+```
+
+on all nodes, then running:
+
+```sh
+$ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME
+```
+
+on all nodes.
+
+## Enhancements/additional features:
+
+- k-diffusion supports an experimental model output type, an isotropic Gaussian, which seems to have a lower gradient noise scale and to train faster than Karras et al. (2022) diffusion models.
+
+- k-diffusion has wrappers for [v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch), [OpenAI diffusion](https://github.com/openai/guided-diffusion), and [CompVis diffusion](https://github.com/CompVis/latent-diffusion) models allowing them to be used with its samplers and ODE/SDE.
+
+- k-diffusion models support progressive growing.
+
+- k-diffusion implements [DPM-Solver](https://arxiv.org/abs/2206.00927), which produces higher quality samples at the same number of function evaluations as Karras Algorithm 2, as well as supporting adaptive step size control. [DPM-Solver++(2S) and (2M)](https://arxiv.org/abs/2211.01095) are implemented now too for improved quality with low numbers of steps; a short sampling sketch follows this list.
+
+- k-diffusion supports [CLIP](https://openai.com/blog/clip/) guided sampling from unconditional diffusion models (see `sample_clip_guided.py`).
+
+- k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.
+
+- k-diffusion can calculate, during training, the [FID](https://papers.nips.cc/paper/2017/file/8a1d694707eb0fefe65871369074926d-Paper.pdf) and [KID](https://arxiv.org/abs/1801.01401) vs the training set.
+
+- k-diffusion can calculate, during training, the gradient noise scale (1 / SNR), from _An Empirical Model of Large-Batch Training_ (https://arxiv.org/abs/1812.06162).
+
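+A minimal end-to-end sampling sketch using the bundled CIFAR-10 config (untrained weights
+on CPU, purely illustrative; load a trained checkpoint and move to a GPU for real samples):
+
+```python
+import torch
+import k_diffusion as K
+
+# Build a model from one of the bundled configs and wrap it with the preconditioning denoiser.
+config = K.config.load_config(open('configs/config_cifar10.json'))
+inner_model = K.config.make_model(config)               # untrained here; load EMA weights in practice
+model = K.config.make_denoiser_wrapper(config)(inner_model).eval()
+
+# Karras noise schedule followed by a DPM-Solver++(2M) sampling run.
+sigmas = K.sampling.get_sigmas_karras(50, sigma_min=1e-2, sigma_max=80., device='cpu')
+x = torch.randn([4, 3, 32, 32]) * sigmas[0]             # start from pure noise at sigma_max
+samples = K.sampling.sample_dpmpp_2m(model, x, sigmas)  # same shape as x
+```
+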
+## To do:
+
+- Anything except unconditional image diffusion models
+
+- Latent diffusion
diff --git a/repositories/k-diffusion/configs/config_32x32_small.json b/repositories/k-diffusion/configs/config_32x32_small.json
new file mode 100644
index 0000000000000000000000000000000000000000..28bacb6274ff278185031fd49db0ea39d93b1299
--- /dev/null
+++ b/repositories/k-diffusion/configs/config_32x32_small.json
@@ -0,0 +1,46 @@
+{
+ "model": {
+ "type": "image_v1",
+ "input_channels": 3,
+ "input_size": [32, 32],
+ "patch_size": 1,
+ "mapping_out": 256,
+ "depths": [2, 4, 4],
+ "channels": [128, 256, 512],
+ "self_attn_depths": [false, true, true],
+ "has_variance": true,
+ "dropout_rate": 0.05,
+ "augment_wrapper": true,
+ "augment_prob": 0.12,
+ "sigma_data": 0.5,
+ "sigma_min": 1e-2,
+ "sigma_max": 80,
+ "sigma_sample_density": {
+ "type": "lognormal",
+ "mean": -1.2,
+ "std": 1.2
+ }
+ },
+ "dataset": {
+ "type": "imagefolder",
+ "location": "/path/to/dataset"
+ },
+ "optimizer": {
+ "type": "adamw",
+ "lr": 1e-4,
+ "betas": [0.95, 0.999],
+ "eps": 1e-6,
+ "weight_decay": 1e-3
+ },
+ "lr_sched": {
+ "type": "inverse",
+ "inv_gamma": 20000.0,
+ "power": 1.0,
+ "warmup": 0.99
+ },
+ "ema_sched": {
+ "type": "inverse",
+ "power": 0.6667,
+ "max_value": 0.9999
+ }
+}
diff --git a/repositories/k-diffusion/configs/config_32x32_small_butterflies.json b/repositories/k-diffusion/configs/config_32x32_small_butterflies.json
new file mode 100644
index 0000000000000000000000000000000000000000..de02c8e2e3f7b21c3dc305eb2fd8eeb144b489cf
--- /dev/null
+++ b/repositories/k-diffusion/configs/config_32x32_small_butterflies.json
@@ -0,0 +1,47 @@
+{
+ "model": {
+ "type": "image_v1",
+ "input_channels": 3,
+ "input_size": [32, 32],
+ "patch_size": 1,
+ "mapping_out": 256,
+ "depths": [2, 4, 4],
+ "channels": [128, 256, 512],
+ "self_attn_depths": [false, true, true],
+ "has_variance": true,
+ "dropout_rate": 0.05,
+ "augment_wrapper": true,
+ "augment_prob": 0.12,
+ "sigma_data": 0.5,
+ "sigma_min": 1e-2,
+ "sigma_max": 80,
+ "sigma_sample_density": {
+ "type": "lognormal",
+ "mean": -1.2,
+ "std": 1.2
+ }
+ },
+ "dataset": {
+ "type": "huggingface",
+ "location": "huggan/smithsonian_butterflies_subset",
+ "image_key": "image"
+ },
+ "optimizer": {
+ "type": "adamw",
+ "lr": 1e-4,
+ "betas": [0.95, 0.999],
+ "eps": 1e-6,
+ "weight_decay": 1e-3
+ },
+ "lr_sched": {
+ "type": "inverse",
+ "inv_gamma": 20000.0,
+ "power": 1.0,
+ "warmup": 0.99
+ },
+ "ema_sched": {
+ "type": "inverse",
+ "power": 0.6667,
+ "max_value": 0.9999
+ }
+}
diff --git a/repositories/k-diffusion/configs/config_cifar10.json b/repositories/k-diffusion/configs/config_cifar10.json
new file mode 100644
index 0000000000000000000000000000000000000000..f719fd7faf4a88152e50cb2da4487178a82483ae
--- /dev/null
+++ b/repositories/k-diffusion/configs/config_cifar10.json
@@ -0,0 +1,46 @@
+{
+ "model": {
+ "type": "image_v1",
+ "input_channels": 3,
+ "input_size": [32, 32],
+ "patch_size": 1,
+ "mapping_out": 256,
+ "depths": [2, 4, 4],
+ "channels": [128, 256, 512],
+ "self_attn_depths": [false, true, true],
+ "has_variance": true,
+ "dropout_rate": 0.05,
+ "augment_wrapper": true,
+ "augment_prob": 0.12,
+ "sigma_data": 0.5,
+ "sigma_min": 1e-2,
+ "sigma_max": 80,
+ "sigma_sample_density": {
+ "type": "lognormal",
+ "mean": -1.2,
+ "std": 1.2
+ }
+ },
+ "dataset": {
+ "type": "cifar10",
+ "location": "data"
+ },
+ "optimizer": {
+ "type": "adamw",
+ "lr": 1e-4,
+ "betas": [0.95, 0.999],
+ "eps": 1e-6,
+ "weight_decay": 1e-3
+ },
+ "lr_sched": {
+ "type": "inverse",
+ "inv_gamma": 20000.0,
+ "power": 1.0,
+ "warmup": 0.99
+ },
+ "ema_sched": {
+ "type": "inverse",
+ "power": 0.6667,
+ "max_value": 0.9999
+ }
+}
diff --git a/repositories/k-diffusion/configs/config_mnist.json b/repositories/k-diffusion/configs/config_mnist.json
new file mode 100644
index 0000000000000000000000000000000000000000..c698913b0aa5db455b5191477b2c86b4edec88b0
--- /dev/null
+++ b/repositories/k-diffusion/configs/config_mnist.json
@@ -0,0 +1,46 @@
+{
+ "model": {
+ "type": "image_v1",
+ "input_channels": 1,
+ "input_size": [28, 28],
+ "patch_size": 1,
+ "mapping_out": 256,
+ "depths": [2, 4, 4],
+ "channels": [128, 128, 256],
+ "self_attn_depths": [false, false, true],
+ "has_variance": true,
+ "dropout_rate": 0.05,
+ "augment_wrapper": true,
+ "augment_prob": 0.12,
+ "sigma_data": 0.6162,
+ "sigma_min": 1e-2,
+ "sigma_max": 80,
+ "sigma_sample_density": {
+ "type": "lognormal",
+ "mean": -1.2,
+ "std": 1.2
+ }
+ },
+ "dataset": {
+ "type": "mnist",
+ "location": "data"
+ },
+ "optimizer": {
+ "type": "adamw",
+ "lr": 2e-4,
+ "betas": [0.95, 0.999],
+ "eps": 1e-6,
+ "weight_decay": 1e-3
+ },
+ "lr_sched": {
+ "type": "inverse",
+ "inv_gamma": 20000.0,
+ "power": 1.0,
+ "warmup": 0.99
+ },
+ "ema_sched": {
+ "type": "inverse",
+ "power": 0.6667,
+ "max_value": 0.9999
+ }
+}
diff --git a/repositories/k-diffusion/k_diffusion/__init__.py b/repositories/k-diffusion/k_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5de9decab9fef99f2dd152f16b82b5806508ffdf
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/__init__.py
@@ -0,0 +1,2 @@
+from . import augmentation, config, evaluation, external, gns, layers, models, sampling, utils
+from .layers import Denoiser
diff --git a/repositories/k-diffusion/k_diffusion/augmentation.py b/repositories/k-diffusion/k_diffusion/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd17c686300c8ecba7fac134aa54f01619c3d46
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/augmentation.py
@@ -0,0 +1,105 @@
+from functools import reduce
+import math
+import operator
+
+import numpy as np
+from skimage import transform
+import torch
+from torch import nn
+
+
+def translate2d(tx, ty):
+ mat = [[1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1]]
+ return torch.tensor(mat, dtype=torch.float32)
+
+
+def scale2d(sx, sy):
+ mat = [[sx, 0, 0],
+ [ 0, sy, 0],
+ [ 0, 0, 1]]
+ return torch.tensor(mat, dtype=torch.float32)
+
+
+def rotate2d(theta):
+ mat = [[torch.cos(theta), torch.sin(-theta), 0],
+ [torch.sin(theta), torch.cos(theta), 0],
+ [ 0, 0, 1]]
+ return torch.tensor(mat, dtype=torch.float32)
+
+
+class KarrasAugmentationPipeline:
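+    """Random geometric augmentation in the style of Karras et al. (2022): x-flips are
+    always randomized, while y-flips, scaling, rotation, anisotropic stretching and
+    translation are each gated by probability ``a_prob``. Returns the augmented image,
+    the original image, and a conditioning vector describing the applied transform."""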
+ def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
+ self.a_prob = a_prob
+ self.a_scale = a_scale
+ self.a_aniso = a_aniso
+ self.a_trans = a_trans
+
+ def __call__(self, image):
+ h, w = image.size
+ mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
+
+ # x-flip
+ a0 = torch.randint(2, []).float()
+ mats.append(scale2d(1 - 2 * a0, 1))
+ # y-flip
+ do = (torch.rand([]) < self.a_prob).float()
+ a1 = torch.randint(2, []).float() * do
+ mats.append(scale2d(1, 1 - 2 * a1))
+ # scaling
+ do = (torch.rand([]) < self.a_prob).float()
+ a2 = torch.randn([]) * do
+ mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
+ # rotation
+ do = (torch.rand([]) < self.a_prob).float()
+ a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
+ mats.append(rotate2d(-a3))
+ # anisotropy
+ do = (torch.rand([]) < self.a_prob).float()
+ a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
+ a5 = torch.randn([]) * do
+ mats.append(rotate2d(a4))
+ mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
+ mats.append(rotate2d(-a4))
+ # translation
+ do = (torch.rand([]) < self.a_prob).float()
+ a6 = torch.randn([]) * do
+ a7 = torch.randn([]) * do
+ mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
+
+ # form the transformation matrix and conditioning vector
+ mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
+ mat = reduce(operator.matmul, mats)
+ cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
+
+ # apply the transformation
+ image_orig = np.array(image, dtype=np.float32) / 255
+ if image_orig.ndim == 2:
+ image_orig = image_orig[..., None]
+ tf = transform.AffineTransform(mat.numpy())
+ image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
+ image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
+ image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
+ return image, image_orig, cond
+
+
+class KarrasAugmentWrapper(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+
+ def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
+ if aug_cond is None:
+ aug_cond = input.new_zeros([input.shape[0], 9])
+ if mapping_cond is None:
+ mapping_cond = aug_cond
+ else:
+ mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
+ return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
+
+ def set_skip_stages(self, skip_stages):
+ return self.inner_model.set_skip_stages(skip_stages)
+
+ def set_patch_size(self, patch_size):
+ return self.inner_model.set_patch_size(patch_size)
diff --git a/repositories/k-diffusion/k_diffusion/config.py b/repositories/k-diffusion/k_diffusion/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b504d6d74b2fbdf92be6aa6f84955832f8c701a
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/config.py
@@ -0,0 +1,110 @@
+from functools import partial
+import json
+import math
+import warnings
+
+from jsonmerge import merge
+
+from . import augmentation, layers, models, utils
+
+
+def load_config(file):
+ defaults = {
+ 'model': {
+ 'sigma_data': 1.,
+ 'patch_size': 1,
+ 'dropout_rate': 0.,
+ 'augment_wrapper': True,
+ 'augment_prob': 0.,
+ 'mapping_cond_dim': 0,
+ 'unet_cond_dim': 0,
+ 'cross_cond_dim': 0,
+ 'cross_attn_depths': None,
+ 'skip_stages': 0,
+ 'has_variance': False,
+ },
+ 'dataset': {
+ 'type': 'imagefolder',
+ },
+ 'optimizer': {
+ 'type': 'adamw',
+ 'lr': 1e-4,
+ 'betas': [0.95, 0.999],
+ 'eps': 1e-6,
+ 'weight_decay': 1e-3,
+ },
+ 'lr_sched': {
+ 'type': 'inverse',
+ 'inv_gamma': 20000.,
+ 'power': 1.,
+ 'warmup': 0.99,
+ },
+ 'ema_sched': {
+ 'type': 'inverse',
+ 'power': 0.6667,
+ 'max_value': 0.9999
+ },
+ }
+ config = json.load(file)
+ return merge(defaults, config)
+
+
+def make_model(config):
+ config = config['model']
+ assert config['type'] == 'image_v1'
+ model = models.ImageDenoiserModelV1(
+ config['input_channels'],
+ config['mapping_out'],
+ config['depths'],
+ config['channels'],
+ config['self_attn_depths'],
+ config['cross_attn_depths'],
+ patch_size=config['patch_size'],
+ dropout_rate=config['dropout_rate'],
+ mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
+ unet_cond_dim=config['unet_cond_dim'],
+ cross_cond_dim=config['cross_cond_dim'],
+ skip_stages=config['skip_stages'],
+ has_variance=config['has_variance'],
+ )
+ if config['augment_wrapper']:
+ model = augmentation.KarrasAugmentWrapper(model)
+ return model
+
+
+def make_denoiser_wrapper(config):
+ config = config['model']
+ sigma_data = config.get('sigma_data', 1.)
+ has_variance = config.get('has_variance', False)
+ if not has_variance:
+ return partial(layers.Denoiser, sigma_data=sigma_data)
+ return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
+
+
+def make_sample_density(config):
+ sd_config = config['sigma_sample_density']
+ sigma_data = config['sigma_data']
+ if sd_config['type'] == 'lognormal':
+ loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
+ scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
+ return partial(utils.rand_log_normal, loc=loc, scale=scale)
+ if sd_config['type'] == 'loglogistic':
+ loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
+ scale = sd_config['scale'] if 'scale' in sd_config else 0.5
+ min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
+ max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
+ return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
+ if sd_config['type'] == 'loguniform':
+ min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
+ max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
+ return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
+ if sd_config['type'] == 'v-diffusion':
+ min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
+ max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
+ return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
+ if sd_config['type'] == 'split-lognormal':
+ loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
+ scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
+ scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
+ return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
+ raise ValueError('Unknown sample density type')
diff --git a/repositories/k-diffusion/k_diffusion/evaluation.py b/repositories/k-diffusion/k_diffusion/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c34bbf1656854d9cf233b7620b684e44b30de82
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/evaluation.py
@@ -0,0 +1,134 @@
+import math
+import os
+from pathlib import Path
+
+from cleanfid.inception_torchscript import InceptionV3W
+import clip
+from resize_right import resize
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from tqdm.auto import trange
+
+from . import utils
+
+
+class InceptionV3FeatureExtractor(nn.Module):
+ def __init__(self, device='cpu'):
+ super().__init__()
+ path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+ digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
+ utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
+ self.model = InceptionV3W(str(path), resize_inside=False).to(device)
+ self.size = (299, 299)
+
+ def forward(self, x):
+ if x.shape[2:4] != self.size:
+ x = resize(x, out_shape=self.size, pad_mode='reflect')
+ if x.shape[1] == 1:
+ x = torch.cat([x] * 3, dim=1)
+ x = (x * 127.5 + 127.5).clamp(0, 255)
+ return self.model(x)
+
+
+class CLIPFeatureExtractor(nn.Module):
+ def __init__(self, name='ViT-L/14@336px', device='cpu'):
+ super().__init__()
+ self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
+ self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711))
+ self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
+
+ def forward(self, x):
+ if x.shape[2:4] != self.size:
+ x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
+ x = self.normalize(x)
+ x = self.model.encode_image(x).float()
+ x = F.normalize(x) * x.shape[1] ** 0.5
+ return x
+
+
+def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
+ n_per_proc = math.ceil(n / accelerator.num_processes)
+ feats_all = []
+ try:
+ for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
+ cur_batch_size = min(n - i, batch_size)
+ samples = sample_fn(cur_batch_size)[:cur_batch_size]
+ feats_all.append(accelerator.gather(extractor_fn(samples)))
+ except StopIteration:
+ pass
+ return torch.cat(feats_all)[:n]
+
+
+def polynomial_kernel(x, y):
+ d = x.shape[-1]
+ dot = x @ y.transpose(-2, -1)
+ return (dot / d + 1) ** 3
+
+
+def squared_mmd(x, y, kernel=polynomial_kernel):
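+    # Unbiased estimate of MMD^2: E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)], with the
+    # diagonal (self-kernel) terms excluded from the two within-set sums.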
+ m = x.shape[-2]
+ n = y.shape[-2]
+ kxx = kernel(x, x)
+ kyy = kernel(y, y)
+ kxy = kernel(x, y)
+ kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
+ kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
+ kxy_sum = kxy.sum([-1, -2])
+ term_1 = kxx_sum / m / (m - 1)
+ term_2 = kyy_sum / n / (n - 1)
+ term_3 = kxy_sum * 2 / m / n
+ return term_1 + term_2 - term_3
+
+
+@utils.tf32_mode(matmul=False)
+def kid(x, y, max_size=5000):
+ x_size, y_size = x.shape[0], y.shape[0]
+ n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
+ total_mmd = x.new_zeros([])
+ for i in range(n_partitions):
+ cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
+ cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
+ total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
+ return total_mmd / n_partitions
+
+
+class _MatrixSquareRootEig(torch.autograd.Function):
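+    # Matrix square root via symmetric eigendecomposition, with a custom backward that
+    # solves the Lyapunov equation S dX + dX S = dA in the eigenbasis (hence the
+    # (d + d.transpose(-2, -1)) divisor below).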
+ @staticmethod
+ def forward(ctx, a):
+ vals, vecs = torch.linalg.eigh(a)
+ ctx.save_for_backward(vals, vecs)
+ return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ vals, vecs = ctx.saved_tensors
+ d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
+ vecs_t = vecs.transpose(-2, -1)
+ return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
+
+
+def sqrtm_eig(a):
+ if a.ndim < 2:
+ raise RuntimeError('tensor of matrices must have at least 2 dimensions')
+ if a.shape[-2] != a.shape[-1]:
+ raise RuntimeError('tensor must be batches of square matrices')
+ return _MatrixSquareRootEig.apply(a)
+
+
+@utils.tf32_mode(matmul=False)
+def fid(x, y, eps=1e-8):
+ x_mean = x.mean(dim=0)
+ y_mean = y.mean(dim=0)
+ mean_term = (x_mean - y_mean).pow(2).sum()
+ x_cov = torch.cov(x.T)
+ y_cov = torch.cov(y.T)
+ eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
+ x_cov = x_cov + eps_eye
+ y_cov = y_cov + eps_eye
+ x_cov_sqrt = sqrtm_eig(x_cov)
+ cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
+ return mean_term + cov_term
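The two metrics above operate on batches of feature vectors rather than on images directly. A minimal sketch, assuming the vendored package is importable as `k_diffusion` and that evaluation.py's optional dependencies (clip, cleanfid, resize_right) are installed; the random tensors are hypothetical stand-ins for features produced by InceptionV3FeatureExtractor or CLIPFeatureExtractor:

import torch

from k_diffusion import evaluation

# Hypothetical features for 500 real and 500 generated images (e.g. the output of
# CLIPFeatureExtractor()(images) on [-1, 1]-range tensors); random here for brevity.
real_feats = torch.randn(500, 768)
fake_feats = torch.randn(500, 768)

print('FID:', evaluation.fid(fake_feats, real_feats).item())
print('KID:', evaluation.kid(fake_feats, real_feats).item())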
diff --git a/repositories/k-diffusion/k_diffusion/external.py b/repositories/k-diffusion/k_diffusion/external.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b51cec41c52f054775f26c26cf63414d588aef
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/external.py
@@ -0,0 +1,177 @@
+import math
+
+import torch
+from torch import nn
+
+from . import sampling, utils
+
+
+class VDenoiser(nn.Module):
+ """A v-diffusion-pytorch model wrapper for k-diffusion."""
+
+ def __init__(self, inner_model):
+ super().__init__()
+ self.inner_model = inner_model
+ self.sigma_data = 1.
+
+ def get_scalings(self, sigma):
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
+ c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ return c_skip, c_out, c_in
+
+ def sigma_to_t(self, sigma):
+ return sigma.atan() / math.pi * 2
+
+ def t_to_sigma(self, t):
+ return (t * math.pi / 2).tan()
+
+ def loss(self, input, noise, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ noised_input = input + noise * utils.append_dims(sigma, input.ndim)
+ model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
+ target = (input - c_skip * noised_input) / c_out
+ return (model_output - target).pow(2).flatten(1).mean(1)
+
+ def forward(self, input, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
+
+
+class DiscreteSchedule(nn.Module):
+ """A mapping between continuous noise levels (sigmas) and a list of discrete noise
+ levels."""
+
+ def __init__(self, sigmas, quantize):
+ super().__init__()
+ self.register_buffer('sigmas', sigmas)
+ self.register_buffer('log_sigmas', sigmas.log())
+ self.quantize = quantize
+
+ @property
+ def sigma_min(self):
+ return self.sigmas[0]
+
+ @property
+ def sigma_max(self):
+ return self.sigmas[-1]
+
+ def get_sigmas(self, n=None):
+ if n is None:
+ return sampling.append_zero(self.sigmas.flip(0))
+ t_max = len(self.sigmas) - 1
+ t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
+ return sampling.append_zero(self.t_to_sigma(t))
+
+ def sigma_to_t(self, sigma, quantize=None):
+ quantize = self.quantize if quantize is None else quantize
+ log_sigma = sigma.log()
+ dists = log_sigma - self.log_sigmas[:, None]
+ if quantize:
+ return dists.abs().argmin(dim=0).view(sigma.shape)
+ low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+ low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
+ w = (low - log_sigma) / (low - high)
+ w = w.clamp(0, 1)
+ t = (1 - w) * low_idx + w * high_idx
+ return t.view(sigma.shape)
+
+ def t_to_sigma(self, t):
+ t = t.float()
+ low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+ return log_sigma.exp()
+
+
+class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
+ """A wrapper for discrete schedule DDPM models that output eps (the predicted
+ noise)."""
+
+ def __init__(self, model, alphas_cumprod, quantize):
+ super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
+ self.inner_model = model
+ self.sigma_data = 1.
+
+ def get_scalings(self, sigma):
+ c_out = -sigma
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ return c_out, c_in
+
+ def get_eps(self, *args, **kwargs):
+ return self.inner_model(*args, **kwargs)
+
+ def loss(self, input, noise, sigma, **kwargs):
+ c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ noised_input = input + noise * utils.append_dims(sigma, input.ndim)
+ eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
+ return (eps - noise).pow(2).flatten(1).mean(1)
+
+ def forward(self, input, sigma, **kwargs):
+ c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+ return input + eps * c_out
+
+
+class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
+ """A wrapper for OpenAI diffusion models."""
+
+ def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
+ alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
+ super().__init__(model, alphas_cumprod, quantize=quantize)
+ self.has_learned_sigmas = has_learned_sigmas
+
+ def get_eps(self, *args, **kwargs):
+ model_output = self.inner_model(*args, **kwargs)
+ if self.has_learned_sigmas:
+ return model_output.chunk(2, dim=1)[0]
+ return model_output
+
+
+class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
+ """A wrapper for CompVis diffusion models."""
+
+ def __init__(self, model, quantize=False, device='cpu'):
+ super().__init__(model, model.alphas_cumprod, quantize=quantize)
+
+ def get_eps(self, *args, **kwargs):
+ return self.inner_model.apply_model(*args, **kwargs)
+
+
+class DiscreteVDDPMDenoiser(DiscreteSchedule):
+ """A wrapper for discrete schedule DDPM models that output v."""
+
+ def __init__(self, model, alphas_cumprod, quantize):
+ super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
+ self.inner_model = model
+ self.sigma_data = 1.
+
+ def get_scalings(self, sigma):
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
+ c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ return c_skip, c_out, c_in
+
+ def get_v(self, *args, **kwargs):
+ return self.inner_model(*args, **kwargs)
+
+ def loss(self, input, noise, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ noised_input = input + noise * utils.append_dims(sigma, input.ndim)
+ model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
+ target = (input - c_skip * noised_input) / c_out
+ return (model_output - target).pow(2).flatten(1).mean(1)
+
+ def forward(self, input, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
+
+
+class CompVisVDenoiser(DiscreteVDDPMDenoiser):
+ """A wrapper for CompVis diffusion models that output v."""
+
+ def __init__(self, model, quantize=False, device='cpu'):
+ super().__init__(model, model.alphas_cumprod, quantize=quantize)
+
+ def get_v(self, x, t, cond, **kwargs):
+ return self.inner_model.apply_model(x, t, cond)
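All of these wrappers expose the same interface: call them with a noisy input and a batch of sigmas and they return the denoised prediction. A small sketch using an invented eps-predicting module and a made-up beta schedule in place of a real CompVis or OpenAI model (the import assumes the vendored package is on the Python path):

import torch
from torch import nn

from k_diffusion.external import DiscreteEpsDDPMDenoiser


class ToyEpsModel(nn.Module):
    """Stands in for a DDPM U-Net: takes (x, t) and returns predicted noise."""
    def forward(self, x, t):
        return torch.zeros_like(x)


# A made-up linear beta schedule, only so the wrapper has discrete sigmas to map to.
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

denoiser = DiscreteEpsDDPMDenoiser(ToyEpsModel(), alphas_cumprod, quantize=False)

x = torch.randn(4, 3, 32, 32) * denoiser.sigma_max
sigma = denoiser.sigma_max.repeat(4)
denoised = denoiser(x, sigma)  # prediction in data space, same shape as x
print(denoised.shape, float(denoiser.sigma_min), float(denoiser.sigma_max))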
diff --git a/repositories/k-diffusion/k_diffusion/gns.py b/repositories/k-diffusion/k_diffusion/gns.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcb7b8d8a9aeae38a7f961c63f66cca4ef90a9e7
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/gns.py
@@ -0,0 +1,99 @@
+import torch
+from torch import nn
+
+
+class DDPGradientStatsHook:
+ def __init__(self, ddp_module):
+ try:
+ ddp_module.register_comm_hook(self, self._hook_fn)
+ except AttributeError:
+ raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
+ self._clear_state()
+
+ def _clear_state(self):
+ self.bucket_sq_norms_small_batch = []
+ self.bucket_sq_norms_large_batch = []
+
+ @staticmethod
+ def _hook_fn(self, bucket):
+ buf = bucket.buffer()
+ self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
+ fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
+ def callback(fut):
+ buf = fut.value()[0]
+ self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
+ return buf
+ return fut.then(callback)
+
+ def get_stats(self):
+ sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
+ sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
+ self._clear_state()
+ stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
+ torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
+ return stats[0].item(), stats[1].item()
+
+
+class GradientNoiseScale:
+ """Calculates the gradient noise scale (1 / SNR), or critical batch size,
+ from _An Empirical Model of Large-Batch Training_,
+    https://arxiv.org/abs/1812.06162.
+
+ Args:
+ beta (float): The decay factor for the exponential moving averages used to
+ calculate the gradient noise scale.
+ Default: 0.9998
+ eps (float): Added for numerical stability.
+ Default: 1e-8
+ """
+
+ def __init__(self, beta=0.9998, eps=1e-8):
+ self.beta = beta
+ self.eps = eps
+ self.ema_sq_norm = 0.
+ self.ema_var = 0.
+ self.beta_cumprod = 1.
+ self.gradient_noise_scale = float('nan')
+
+ def state_dict(self):
+ """Returns the state of the object as a :class:`dict`."""
+ return dict(self.__dict__.items())
+
+ def load_state_dict(self, state_dict):
+ """Loads the object's state.
+ Args:
+ state_dict (dict): object state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
+ """Updates the state with a new batch's gradient statistics, and returns the
+ current gradient noise scale.
+
+ Args:
+ sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
+ per sample gradients.
+ sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
+ per sample gradients.
+ n_small_batch (int): The batch size of the individual microbatch or per sample
+ gradients (1 if per sample).
+ n_large_batch (int): The total batch size of the mean of the microbatch or
+ per sample gradients.
+ """
+ est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
+ est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
+ self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
+ self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
+ self.beta_cumprod *= self.beta
+ self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
+ return self.gradient_noise_scale
+
+ def get_gns(self):
+ """Returns the current gradient noise scale."""
+ return self.gradient_noise_scale
+
+ def get_stats(self):
+ """Returns the current (debiased) estimates of the squared mean gradient
+ and gradient variance."""
+ return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
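Outside of DDP training, the statistics fed to GradientNoiseScale.update can come from any pair of gradient-norm measurements taken at two batch sizes; the numbers below are invented purely to show the call pattern:

from k_diffusion.gns import GradientNoiseScale

gns = GradientNoiseScale(beta=0.9998)

for _ in range(100):
    # Invented measurements: squared 2-norm of a size-8 microbatch gradient and of
    # the size-64 full-batch gradient (in DDP training these would come from
    # DDPGradientStatsHook.get_stats()).
    gns.update(sq_norm_small_batch=1.3, sq_norm_large_batch=1.05,
               n_small_batch=8, n_large_batch=64)

print('estimated gradient noise scale:', gns.get_gns())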
diff --git a/repositories/k-diffusion/k_diffusion/layers.py b/repositories/k-diffusion/k_diffusion/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdeba0ad68f584261bd88de608e843a350489544
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/layers.py
@@ -0,0 +1,246 @@
+import math
+
+from einops import rearrange, repeat
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from . import utils
+
+# Karras et al. preconditioned denoiser
+
+class Denoiser(nn.Module):
+ """A Karras et al. preconditioner for denoising diffusion models."""
+
+ def __init__(self, inner_model, sigma_data=1.):
+ super().__init__()
+ self.inner_model = inner_model
+ self.sigma_data = sigma_data
+
+ def get_scalings(self, sigma):
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
+ c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+ return c_skip, c_out, c_in
+
+ def loss(self, input, noise, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ noised_input = input + noise * utils.append_dims(sigma, input.ndim)
+ model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
+ target = (input - c_skip * noised_input) / c_out
+ return (model_output - target).pow(2).flatten(1).mean(1)
+
+ def forward(self, input, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
+
+
+class DenoiserWithVariance(Denoiser):
+ def loss(self, input, noise, sigma, **kwargs):
+ c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ noised_input = input + noise * utils.append_dims(sigma, input.ndim)
+ model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
+ logvar = utils.append_dims(logvar, model_output.ndim)
+ target = (input - c_skip * noised_input) / c_out
+ losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
+ return losses.flatten(1).mean(1)
+
+
+# Residual blocks
+
+class ResidualBlock(nn.Module):
+ def __init__(self, *main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ return self.main(input) + self.skip(input)
+
+
+# Noise level (and other) conditioning
+
+class ConditionedModule(nn.Module):
+ pass
+
+
+class UnconditionedModule(ConditionedModule):
+ def __init__(self, module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, input, cond=None):
+ return self.module(input)
+
+
+class ConditionedSequential(nn.Sequential, ConditionedModule):
+ def forward(self, input, cond):
+ for module in self:
+ if isinstance(module, ConditionedModule):
+ input = module(input, cond)
+ else:
+ input = module(input)
+ return input
+
+
+class ConditionedResidualBlock(ConditionedModule):
+ def __init__(self, *main, skip=None):
+ super().__init__()
+ self.main = ConditionedSequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input, cond):
+ skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
+ return self.main(input, cond) + skip
+
+
+class AdaGN(ConditionedModule):
+ def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+ self.cond_key = cond_key
+ self.mapper = nn.Linear(feats_in, c_out * 2)
+
+ def forward(self, input, cond):
+ weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
+ input = F.group_norm(input, self.num_groups, eps=self.eps)
+ return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
+
+
+# Attention
+
+class SelfAttention2d(ConditionedModule):
+ def __init__(self, c_in, n_head, norm, dropout_rate=0.):
+ super().__init__()
+ assert c_in % n_head == 0
+ self.norm_in = norm(c_in)
+ self.n_head = n_head
+ self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
+ self.out_proj = nn.Conv2d(c_in, c_in, 1)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, input, cond):
+ n, c, h, w = input.shape
+ qkv = self.qkv_proj(self.norm_in(input, cond))
+ qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = k.shape[3] ** -0.25
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
+ att = self.dropout(att)
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
+ return input + self.out_proj(y)
+
+
+class CrossAttention2d(ConditionedModule):
+ def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
+ cond_key='cross', cond_key_padding='cross_padding'):
+ super().__init__()
+ assert c_dec % n_head == 0
+ self.cond_key = cond_key
+ self.cond_key_padding = cond_key_padding
+ self.norm_enc = nn.LayerNorm(c_enc)
+ self.norm_dec = norm_dec(c_dec)
+ self.n_head = n_head
+ self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
+ self.kv_proj = nn.Linear(c_enc, c_dec * 2)
+ self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def forward(self, input, cond):
+ n, c, h, w = input.shape
+ q = self.q_proj(self.norm_dec(input, cond))
+ q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
+ kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
+ kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
+ k, v = kv.chunk(2, dim=1)
+ scale = k.shape[3] ** -0.25
+ att = ((q * scale) @ (k.transpose(2, 3) * scale))
+ att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
+ att = att.softmax(3)
+ att = self.dropout(att)
+ y = (att @ v).transpose(2, 3)
+ y = y.contiguous().view([n, c, h, w])
+ return input + self.out_proj(y)
+
+
+# Downsampling/upsampling
+
+_kernels = {
+ 'linear':
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+ 'cubic':
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+ 'lanczos3':
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
+}
+_kernels['bilinear'] = _kernels['linear']
+_kernels['bicubic'] = _kernels['cubic']
+
+
+class Downsample2d(nn.Module):
+ def __init__(self, kernel='linear', pad_mode='reflect'):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor([_kernels[kernel]])
+ self.pad = kernel_1d.shape[1] // 2 - 1
+ self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
+
+ def forward(self, x):
+ x = F.pad(x, (self.pad,) * 4, self.pad_mode)
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+ indices = torch.arange(x.shape[1], device=x.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ return F.conv2d(x, weight, stride=2)
+
+
+class Upsample2d(nn.Module):
+ def __init__(self, kernel='linear', pad_mode='reflect'):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor([_kernels[kernel]]) * 2
+ self.pad = kernel_1d.shape[1] // 2 - 1
+ self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
+
+ def forward(self, x):
+ x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+ indices = torch.arange(x.shape[1], device=x.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+# Embeddings
+
+class FourierFeatures(nn.Module):
+ def __init__(self, in_features, out_features, std=1.):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
+
+ def forward(self, input):
+ f = 2 * math.pi * input @ self.weight.T
+ return torch.cat([f.cos(), f.sin()], dim=-1)
+
+
+# U-Nets
+
+class UNet(ConditionedModule):
+ def __init__(self, d_blocks, u_blocks, skip_stages=0):
+ super().__init__()
+ self.d_blocks = nn.ModuleList(d_blocks)
+ self.u_blocks = nn.ModuleList(u_blocks)
+ self.skip_stages = skip_stages
+
+ def forward(self, input, cond):
+ skips = []
+ for block in self.d_blocks[self.skip_stages:]:
+ input = block(input, cond)
+ skips.append(input)
+ for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
+ input = block(input, cond, skip if i > 0 else None)
+ return input
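A quick shape-level sketch of two of the building blocks above, FourierFeatures and the Downsample2d/Upsample2d pair; the dimensions are arbitrary and no training is implied:

import torch

from k_diffusion.layers import Downsample2d, FourierFeatures, Upsample2d

emb = FourierFeatures(in_features=1, out_features=128)
c_noise = torch.randn(4, 1)        # e.g. log(sigma) / 4, one value per sample
print(emb(c_noise).shape)          # torch.Size([4, 128])

x = torch.randn(4, 64, 32, 32)
down, up = Downsample2d(kernel='linear'), Upsample2d(kernel='linear')
print(down(x).shape)               # torch.Size([4, 64, 16, 16])
print(up(down(x)).shape)           # torch.Size([4, 64, 32, 32])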
diff --git a/repositories/k-diffusion/k_diffusion/models/__init__.py b/repositories/k-diffusion/k_diffusion/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82608ff1de6137b31eeaf8de6814df6a7e35606a
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/models/__init__.py
@@ -0,0 +1 @@
+from .image_v1 import ImageDenoiserModelV1
diff --git a/repositories/k-diffusion/k_diffusion/models/image_v1.py b/repositories/k-diffusion/k_diffusion/models/image_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ffd5f2c4d6c9d086107d5fac67452419696c723
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/models/image_v1.py
@@ -0,0 +1,156 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .. import layers, utils
+
+
+def orthogonal_(module):
+ nn.init.orthogonal_(module.weight)
+ return module
+
+
+class ResConvBlock(layers.ConditionedResidualBlock):
+ def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
+ skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
+ super().__init__(
+ layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
+ nn.GELU(),
+ nn.Conv2d(c_in, c_mid, 3, padding=1),
+ nn.Dropout2d(dropout_rate, inplace=True),
+ layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
+ nn.GELU(),
+ nn.Conv2d(c_mid, c_out, 3, padding=1),
+ nn.Dropout2d(dropout_rate, inplace=True),
+ skip=skip)
+
+
+class DBlock(layers.ConditionedSequential):
+ def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
+ modules = [nn.Identity()]
+ for i in range(n_layers):
+ my_c_in = c_in if i == 0 else c_mid
+ my_c_out = c_mid if i < n_layers - 1 else c_out
+ modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
+ if self_attn:
+ norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
+ modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
+ if cross_attn:
+ norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
+ modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
+ super().__init__(*modules)
+ self.set_downsample(downsample)
+
+ def set_downsample(self, downsample):
+ self[0] = layers.Downsample2d() if downsample else nn.Identity()
+ return self
+
+
+class UBlock(layers.ConditionedSequential):
+ def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
+ modules = []
+ for i in range(n_layers):
+ my_c_in = c_in if i == 0 else c_mid
+ my_c_out = c_mid if i < n_layers - 1 else c_out
+ modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
+ if self_attn:
+ norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
+ modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
+ if cross_attn:
+ norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
+ modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
+ modules.append(nn.Identity())
+ super().__init__(*modules)
+ self.set_upsample(upsample)
+
+ def forward(self, input, cond, skip=None):
+ if skip is not None:
+ input = torch.cat([input, skip], dim=1)
+ return super().forward(input, cond)
+
+ def set_upsample(self, upsample):
+ self[-1] = layers.Upsample2d() if upsample else nn.Identity()
+ return self
+
+
+class MappingNet(nn.Sequential):
+ def __init__(self, feats_in, feats_out, n_layers=2):
+ layers = []
+ for i in range(n_layers):
+ layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
+ layers.append(nn.GELU())
+ super().__init__(*layers)
+
+
+class ImageDenoiserModelV1(nn.Module):
+ def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
+ super().__init__()
+ self.c_in = c_in
+ self.channels = channels
+ self.unet_cond_dim = unet_cond_dim
+ self.patch_size = patch_size
+ self.has_variance = has_variance
+ self.timestep_embed = layers.FourierFeatures(1, feats_in)
+ if mapping_cond_dim > 0:
+ self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
+ self.mapping = MappingNet(feats_in, feats_in)
+ self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
+ self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
+ nn.init.zeros_(self.proj_out.weight)
+ nn.init.zeros_(self.proj_out.bias)
+ if cross_cond_dim == 0:
+ cross_attn_depths = [False] * len(self_attn_depths)
+ d_blocks, u_blocks = [], []
+ for i in range(len(depths)):
+ my_c_in = channels[max(0, i - 1)]
+ d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
+ for i in range(len(depths)):
+ my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
+ my_c_out = channels[max(0, i - 1)]
+ u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
+ self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
+
+ def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
+ c_noise = sigma.log() / 4
+ timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
+ mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
+ mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
+ cond = {'cond': mapping_out}
+ if unet_cond is not None:
+ input = torch.cat([input, unet_cond], dim=1)
+ if cross_cond is not None:
+ cond['cross'] = cross_cond
+ cond['cross_padding'] = cross_cond_padding
+ if self.patch_size > 1:
+ input = F.pixel_unshuffle(input, self.patch_size)
+ input = self.proj_in(input)
+ input = self.u_net(input, cond)
+ input = self.proj_out(input)
+ if self.has_variance:
+ input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
+ if self.patch_size > 1:
+ input = F.pixel_shuffle(input, self.patch_size)
+ if self.has_variance and return_variance:
+ return input, logvar
+ return input
+
+ def set_skip_stages(self, skip_stages):
+ self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
+ self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
+ nn.init.zeros_(self.proj_out.weight)
+ nn.init.zeros_(self.proj_out.bias)
+ self.u_net.skip_stages = skip_stages
+ for i, block in enumerate(self.u_net.d_blocks):
+ block.set_downsample(i > skip_stages)
+ for i, block in enumerate(reversed(self.u_net.u_blocks)):
+ block.set_upsample(i > skip_stages)
+ return self
+
+ def set_patch_size(self, patch_size):
+ self.patch_size = patch_size
+ self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
+ self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
+ nn.init.zeros_(self.proj_out.weight)
+ nn.init.zeros_(self.proj_out.bias)
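A minimal configuration sketch for ImageDenoiserModelV1; the depths and channel counts are arbitrary small values chosen so a forward pass runs quickly, not a recommended architecture, and the import assumes the vendored package is on the Python path:

import torch

from k_diffusion.models import ImageDenoiserModelV1

model = ImageDenoiserModelV1(
    c_in=3,                        # RGB images
    feats_in=128,                  # width of the mapping network / AdaGN conditioning
    depths=[2, 2],                 # residual blocks per resolution level
    channels=[64, 128],
    self_attn_depths=[False, True],
)

x = torch.randn(2, 3, 32, 32)
sigma = torch.full([2], 1.0)
out = model(x, sigma)
print(out.shape)                   # torch.Size([2, 3, 32, 32])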
diff --git a/repositories/k-diffusion/k_diffusion/sampling.py b/repositories/k-diffusion/k_diffusion/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f050f88e43bf5d0073cddbbb9f085f7137835fd1
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/sampling.py
@@ -0,0 +1,607 @@
+import math
+
+from scipy import integrate
+import torch
+from torch import nn
+from torchdiffeq import odeint
+import torchsde
+from tqdm.auto import trange, tqdm
+
+from . import utils
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+ """Constructs the noise schedule of Karras et al. (2022)."""
+ ramp = torch.linspace(0, 1, n)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return append_zero(sigmas).to(device)
+
+
+def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
+ """Constructs an exponential noise schedule."""
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
+ return append_zero(sigmas)
+
+
+def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
+    """Constructs a polynomial in log sigma noise schedule."""
+ ramp = torch.linspace(1, 0, n, device=device) ** rho
+ sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
+ return append_zero(sigmas)
+
+
+def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
+ """Constructs a continuous VP noise schedule."""
+ t = torch.linspace(1, eps_s, n, device=device)
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+ return append_zero(sigmas)
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / utils.append_dims(sigma, x.ndim)
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.):
+ """Calculates the noise level (sigma_down) to step down to and the amount
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
+ if not eta:
+ return sigma_to, 0.
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+ return sigma_down, sigma_up
+
+
+def default_noise_sampler(x):
+ return lambda sigma, sigma_next: torch.randn_like(x)
+
+
+class BatchedBrownianTree:
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
+ t0, t1, self.sign = self.sort(t0, t1)
+ w0 = kwargs.get('w0', torch.zeros_like(x))
+ if seed is None:
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
+ self.batched = True
+ try:
+ assert len(seed) == x.shape[0]
+ w0 = w0[0]
+ except TypeError:
+ seed = [seed]
+ self.batched = False
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+ @staticmethod
+ def sort(a, b):
+ return (a, b, 1) if a < b else (b, a, -1)
+
+ def __call__(self, t0, t1):
+ t0, t1, sign = self.sort(t0, t1)
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+ return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+ """A noise sampler backed by a torchsde.BrownianTree.
+
+ Args:
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
+ random samples.
+ sigma_min (float): The low end of the valid interval.
+ sigma_max (float): The high end of the valid interval.
+ seed (int or List[int]): The random seed. If a list of seeds is
+ supplied instead of a single integer, then the noise sampler will
+ use one BrownianTree per batch item, each with its own seed.
+ transform (callable): A function that maps sigma to the sampler's
+ internal timestep.
+ """
+
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+ self.transform = transform
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+ def __call__(self, sigma, sigma_next):
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
+@torch.no_grad()
+def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ dt = sigmas[i + 1] - sigma_hat
+ # Euler method
+ x = x + d * dt
+ return x
+
+
+@torch.no_grad()
+def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+ """Ancestral sampling with Euler method steps."""
+ extra_args = {} if extra_args is None else extra_args
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ d = to_d(x, sigmas[i], denoised)
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ if sigmas[i + 1] > 0:
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+ return x
+
+
+@torch.no_grad()
+def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ dt = sigmas[i + 1] - sigma_hat
+ if sigmas[i + 1] == 0:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ return x
+
+
+@torch.no_grad()
+def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+ eps = torch.randn_like(x) * s_noise
+ sigma_hat = sigmas[i] * (gamma + 1)
+ if gamma > 0:
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+ denoised = model(x, sigma_hat * s_in, **extra_args)
+ d = to_d(x, sigma_hat, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+ if sigmas[i + 1] == 0:
+ # Euler method
+ dt = sigmas[i + 1] - sigma_hat
+ x = x + d * dt
+ else:
+ # DPM-Solver-2
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
+ dt_1 = sigma_mid - sigma_hat
+ dt_2 = sigmas[i + 1] - sigma_hat
+ x_2 = x + d * dt_1
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ return x
+
+
+@torch.no_grad()
+def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+ """Ancestral sampling with DPM-Solver second-order steps."""
+ extra_args = {} if extra_args is None else extra_args
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ d = to_d(x, sigmas[i], denoised)
+ if sigma_down == 0:
+ # Euler method
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ else:
+ # DPM-Solver-2
+ sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
+ dt_1 = sigma_mid - sigmas[i]
+ dt_2 = sigma_down - sigmas[i]
+ x_2 = x + d * dt_1
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
+ x = x + d_2 * dt_2
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+ return x
+
+
+def linear_multistep_coeff(order, t, i, j):
+ if order - 1 > i:
+ raise ValueError(f'Order {order} too high for step {i}')
+ def fn(tau):
+ prod = 1.
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
+
+
+@torch.no_grad()
+def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ sigmas_cpu = sigmas.detach().cpu().numpy()
+ ds = []
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ d = to_d(x, sigmas[i], denoised)
+ ds.append(d)
+ if len(ds) > order:
+ ds.pop(0)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ cur_order = min(i + 1, order)
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+ return x
+
+
+@torch.no_grad()
+def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ v = torch.randint_like(x, 2) * 2 - 1
+ fevals = 0
+ def ode_fn(sigma, x):
+ nonlocal fevals
+ with torch.enable_grad():
+ x = x[0].detach().requires_grad_()
+ denoised = model(x, sigma * s_in, **extra_args)
+ d = to_d(x, sigma, denoised)
+ fevals += 1
+ grad = torch.autograd.grad((d * v).sum(), x)[0]
+ d_ll = (v * grad).flatten(1).sum(1)
+ return d.detach(), d_ll
+ x_min = x, x.new_zeros([x.shape[0]])
+ t = x.new_tensor([sigma_min, sigma_max])
+ sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
+ latent, delta_ll = sol[0][-1], sol[1][-1]
+ ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
+ return ll_prior + delta_ll, {'fevals': fevals}
+
+
+class PIDStepSizeController:
+ """A PID controller for ODE adaptive step size control."""
+ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
+ self.h = h
+ self.b1 = (pcoeff + icoeff + dcoeff) / order
+ self.b2 = -(pcoeff + 2 * dcoeff) / order
+ self.b3 = dcoeff / order
+ self.accept_safety = accept_safety
+ self.eps = eps
+ self.errs = []
+
+ def limiter(self, x):
+ return 1 + math.atan(x - 1)
+
+ def propose_step(self, error):
+ inv_error = 1 / (float(error) + self.eps)
+ if not self.errs:
+ self.errs = [inv_error, inv_error, inv_error]
+ self.errs[0] = inv_error
+ factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
+ factor = self.limiter(factor)
+ accept = factor >= self.accept_safety
+ if accept:
+ self.errs[2] = self.errs[1]
+ self.errs[1] = self.errs[0]
+ self.h *= factor
+ return accept
+
+
+class DPMSolver(nn.Module):
+ """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
+
+ def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
+ super().__init__()
+ self.model = model
+ self.extra_args = {} if extra_args is None else extra_args
+ self.eps_callback = eps_callback
+ self.info_callback = info_callback
+
+ def t(self, sigma):
+ return -sigma.log()
+
+ def sigma(self, t):
+ return t.neg().exp()
+
+ def eps(self, eps_cache, key, x, t, *args, **kwargs):
+ if key in eps_cache:
+ return eps_cache[key], eps_cache
+ sigma = self.sigma(t) * x.new_ones([x.shape[0]])
+ eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
+ if self.eps_callback is not None:
+ self.eps_callback()
+ return eps, {key: eps, **eps_cache}
+
+ def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
+ eps_cache = {} if eps_cache is None else eps_cache
+ h = t_next - t
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+ x_1 = x - self.sigma(t_next) * h.expm1() * eps
+ return x_1, eps_cache
+
+ def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
+ eps_cache = {} if eps_cache is None else eps_cache
+ h = t_next - t
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+ s1 = t + r1 * h
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
+ x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
+ return x_2, eps_cache
+
+ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
+ eps_cache = {} if eps_cache is None else eps_cache
+ h = t_next - t
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+ s1 = t + r1 * h
+ s2 = t + r2 * h
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
+ u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
+ eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
+ x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
+ return x_3, eps_cache
+
+ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ if not t_end > t_start and eta:
+ raise ValueError('eta must be 0 for reverse sampling')
+
+ m = math.floor(nfe / 3) + 1
+ ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
+
+ if nfe % 3 == 0:
+ orders = [3] * (m - 2) + [2, 1]
+ else:
+ orders = [3] * (m - 1) + [nfe % 3]
+
+ for i in range(len(orders)):
+ eps_cache = {}
+ t, t_next = ts[i], ts[i + 1]
+ if eta:
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
+ t_next_ = torch.minimum(t_end, self.t(sd))
+ su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
+ else:
+ t_next_, su = t_next, 0.
+
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
+ denoised = x - self.sigma(t) * eps
+ if self.info_callback is not None:
+ self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
+
+ if orders[i] == 1:
+ x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
+ elif orders[i] == 2:
+ x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
+ else:
+ x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
+
+ x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
+
+ return x
+
+ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ if order not in {2, 3}:
+ raise ValueError('order should be 2 or 3')
+ forward = t_end > t_start
+ if not forward and eta:
+ raise ValueError('eta must be 0 for reverse sampling')
+ h_init = abs(h_init) * (1 if forward else -1)
+ atol = torch.tensor(atol)
+ rtol = torch.tensor(rtol)
+ s = t_start
+ x_prev = x
+ accept = True
+ pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
+ info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
+
+ while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
+ eps_cache = {}
+ t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
+ if eta:
+ sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
+ t_ = torch.minimum(t_end, self.t(sd))
+ su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
+ else:
+ t_, su = t, 0.
+
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
+ denoised = x - self.sigma(s) * eps
+
+ if order == 2:
+ x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
+ x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
+ else:
+ x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
+ x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
+ delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
+ error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
+ accept = pid.propose_step(error)
+ if accept:
+ x_prev = x_low
+ x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
+ s = t
+ info['n_accept'] += 1
+ else:
+ info['n_reject'] += 1
+ info['nfe'] += order
+ info['steps'] += 1
+
+ if self.info_callback is not None:
+ self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
+
+ return x, info
+
+
+@torch.no_grad()
+def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
+ """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
+ if sigma_min <= 0 or sigma_max <= 0:
+ raise ValueError('sigma_min and sigma_max must not be 0')
+ with tqdm(total=n, disable=disable) as pbar:
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
+ if callback is not None:
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
+ return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
+
+
+@torch.no_grad()
+def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
+ """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
+ if sigma_min <= 0 or sigma_max <= 0:
+ raise ValueError('sigma_min and sigma_max must not be 0')
+ with tqdm(disable=disable) as pbar:
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
+ if callback is not None:
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
+ x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
+ if return_info:
+ return x, info
+ return x
+
+
+@torch.no_grad()
+def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
+ extra_args = {} if extra_args is None else extra_args
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+ sigma_fn = lambda t: t.neg().exp()
+ t_fn = lambda sigma: sigma.log().neg()
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ if sigma_down == 0:
+ # Euler method
+ d = to_d(x, sigmas[i], denoised)
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ else:
+ # DPM-Solver++(2S)
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
+ r = 1 / 2
+ h = t_next - t
+ s = t + r * h
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
+ # Noise addition
+ if sigmas[i + 1] > 0:
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+ return x
+
+
+@torch.no_grad()
+def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
+ """DPM-Solver++ (stochastic)."""
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ sigma_fn = lambda t: t.neg().exp()
+ t_fn = lambda sigma: sigma.log().neg()
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ if sigmas[i + 1] == 0:
+ # Euler method
+ d = to_d(x, sigmas[i], denoised)
+ dt = sigmas[i + 1] - sigmas[i]
+ x = x + d * dt
+ else:
+ # DPM-Solver++
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+ h = t_next - t
+ s = t + h * r
+ fac = 1 / (2 * r)
+
+ # Step 1
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
+ s_ = t_fn(sd)
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+
+ # Step 2
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
+ t_next_ = t_fn(sd)
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+ return x
+
+
+@torch.no_grad()
+def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
+ """DPM-Solver++(2M)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ sigma_fn = lambda t: t.neg().exp()
+ t_fn = lambda sigma: sigma.log().neg()
+ old_denoised = None
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+ h = t_next - t
+ if old_denoised is None or sigmas[i + 1] == 0:
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+ else:
+ h_last = t - t_fn(sigmas[i - 1])
+ r = h_last / h
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+ old_denoised = denoised
+ return x
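All of the samplers above share the signature sample_*(model, x, sigmas, ...), where model is a wrapped denoiser (see external.py) and sigmas is a decreasing noise schedule ending in zero. A compact sketch with a toy denoiser standing in for a wrapped diffusion model:

import torch

from k_diffusion import sampling


class ToyDenoiser(torch.nn.Module):
    """Pretends every input denoises to zero; the signature matches the real wrappers."""
    def forward(self, x, sigma, **kwargs):
        return torch.zeros_like(x)


sigmas = sampling.get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6, device='cpu')
x = torch.randn(1, 3, 64, 64) * sigmas[0]   # start from pure noise at sigma_max
sample = sampling.sample_euler(ToyDenoiser(), x, sigmas, disable=True)
print(sample.shape)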
diff --git a/repositories/k-diffusion/k_diffusion/utils.py b/repositories/k-diffusion/k_diffusion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9afedb99276d55d5b923a04ffb62d403c9dfae93
--- /dev/null
+++ b/repositories/k-diffusion/k_diffusion/utils.py
@@ -0,0 +1,329 @@
+from contextlib import contextmanager
+import hashlib
+import math
+from pathlib import Path
+import shutil
+import urllib
+import warnings
+
+from PIL import Image
+import torch
+from torch import nn, optim
+from torch.utils import data
+from torchvision.transforms import functional as TF
+
+
+def from_pil_image(x):
+ """Converts from a PIL image to a tensor."""
+ x = TF.to_tensor(x)
+ if x.ndim == 2:
+ x = x[..., None]
+ return x * 2 - 1
+
+
+def to_pil_image(x):
+ """Converts from a tensor to a PIL image."""
+ if x.ndim == 4:
+ assert x.shape[0] == 1
+ x = x[0]
+ if x.shape[0] == 1:
+ x = x[0]
+ return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
+
+
+def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
+ """Apply passed in transforms for HuggingFace Datasets."""
+ images = [transform(image.convert(mode)) for image in examples[image_key]]
+ return {image_key: images}
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def n_params(module):
+ """Returns the number of trainable parameters in a module."""
+ return sum(p.numel() for p in module.parameters())
+
+
+def download_file(path, url, digest=None):
+ """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
+ path = Path(path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ if not path.exists():
+ with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
+ shutil.copyfileobj(response, f)
+ if digest is not None:
+ file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
+ if digest != file_digest:
+ raise OSError(f'hash of {path} (url: {url}) failed to validate')
+ return path
+
+
+@contextmanager
+def train_mode(model, mode=True):
+ """A context manager that places a model into training mode and restores
+ the previous mode on exit."""
+ modes = [module.training for module in model.modules()]
+ try:
+ yield model.train(mode)
+ finally:
+ for i, module in enumerate(model.modules()):
+ module.training = modes[i]
+
+
+def eval_mode(model):
+ """A context manager that places a model into evaluation mode and restores
+ the previous mode on exit."""
+ return train_mode(model, False)
+
+
+@torch.no_grad()
+def ema_update(model, averaged_model, decay):
+ """Incorporates updated model parameters into an exponential moving averaged
+ version of a model. It should be called after each optimizer step."""
+ model_params = dict(model.named_parameters())
+ averaged_params = dict(averaged_model.named_parameters())
+ assert model_params.keys() == averaged_params.keys()
+
+ for name, param in model_params.items():
+ averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
+
+ model_buffers = dict(model.named_buffers())
+ averaged_buffers = dict(averaged_model.named_buffers())
+ assert model_buffers.keys() == averaged_buffers.keys()
+
+ for name, buf in model_buffers.items():
+ averaged_buffers[name].copy_(buf)
+
+
+class EMAWarmup:
+ """Implements an EMA warmup using an inverse decay schedule.
+ If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
+ good values for models you plan to train for a million or more steps (reaches decay
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+ 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 1.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ max_value (float): The maximum EMA decay rate. Default: 1.
+ start_at (int): The epoch to start averaging at. Default: 0.
+ last_epoch (int): The index of last epoch. Default: 0.
+ """
+
+ def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
+ last_epoch=0):
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+ self.start_at = start_at
+ self.last_epoch = last_epoch
+
+ def state_dict(self):
+ """Returns the state of the class as a :class:`dict`."""
+ return dict(self.__dict__.items())
+
+ def load_state_dict(self, state_dict):
+ """Loads the class's state.
+ Args:
+ state_dict (dict): scaler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def get_value(self):
+ """Gets the current EMA decay rate."""
+ epoch = max(0, self.last_epoch - self.start_at)
+ value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
+ return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
+
+ def step(self):
+ """Updates the step count."""
+ self.last_epoch += 1
+
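A minimal sketch of how `ema_update` and `EMAWarmup` combine after each optimizer step, mirroring the training loop further down in `train.py`; the toy `nn.Linear` model and SGD settings are placeholders, not part of the library:

```python
from copy import deepcopy

import torch
from torch import nn, optim

import k_diffusion as K

model = nn.Linear(8, 8)                       # placeholder model
model_ema = deepcopy(model)
opt = optim.SGD(model.parameters(), lr=1e-2)
ema_sched = K.utils.EMAWarmup(power=2 / 3, max_value=0.9999)

for step in range(100):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # The decay rate ramps up from 0 toward max_value as training progresses.
    K.utils.ema_update(model, model_ema, ema_sched.get_value())
    ema_sched.step()
```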
+
+class InverseLR(optim.lr_scheduler._LRScheduler):
+ """Implements an inverse decay learning rate schedule with an optional exponential
+ warmup. When last_epoch=-1, sets initial lr as lr.
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
+ (1 / 2)**power of its original value.
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
+ power (float): Exponential factor of learning rate decay. Default: 1.
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
+ Default: 0.
+ min_lr (float): The minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ verbose (bool): If ``True``, prints a message to stdout for
+ each update. Default: ``False``.
+ """
+
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
+ last_epoch=-1, verbose=False):
+ self.inv_gamma = inv_gamma
+ self.power = power
+ if not 0. <= warmup < 1:
+ raise ValueError('Invalid value for warmup')
+ self.warmup = warmup
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn("To get the last learning rate computed by the scheduler, "
+ "please use `get_last_lr()`.")
+
+ return self._get_closed_form_lr()
+
+ def _get_closed_form_lr(self):
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
+ for base_lr in self.base_lrs]
+
+
+class ExponentialLR(optim.lr_scheduler._LRScheduler):
+ """Implements an exponential learning rate schedule with an optional exponential
+ warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
+ continuously by decay (default 0.5) every num_steps steps.
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ num_steps (float): The number of steps to decay the learning rate by decay in.
+ decay (float): The factor by which to decay the learning rate every num_steps
+ steps. Default: 0.5.
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
+ Default: 0.
+ min_lr (float): The minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ verbose (bool): If ``True``, prints a message to stdout for
+ each update. Default: ``False``.
+ """
+
+ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
+ last_epoch=-1, verbose=False):
+ self.num_steps = num_steps
+ self.decay = decay
+ if not 0. <= warmup < 1:
+ raise ValueError('Invalid value for warmup')
+ self.warmup = warmup
+ self.min_lr = min_lr
+ super().__init__(optimizer, last_epoch, verbose)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn("To get the last learning rate computed by the scheduler, "
+ "please use `get_last_lr()`.")
+
+ return self._get_closed_form_lr()
+
+ def _get_closed_form_lr(self):
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
+ lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
+ for base_lr in self.base_lrs]
+
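Both schedulers compute the learning rate in closed form from `last_epoch`. A short sketch of attaching one to an optimizer; the toy parameter and hyperparameters are assumptions for illustration:

```python
import torch
from torch import optim

import k_diffusion as K

params = [torch.nn.Parameter(torch.zeros(10))]   # toy parameter
opt = optim.AdamW(params, lr=1e-3)

# With power=1, InverseLR reaches half the base lr once last_epoch == inv_gamma.
sched = K.utils.InverseLR(opt, inv_gamma=20000, power=1., warmup=0.99)

for step in range(1000):
    opt.step()
    sched.step()
    if step % 250 == 0:
        print(step, sched.get_last_lr()[0])
```

`ExponentialLR` is used the same way; there `num_steps` and `decay` set how many steps it takes for the learning rate to shrink by the given factor.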
+
+def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
+ """Draws samples from an lognormal distribution."""
+ return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
+
+
+def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
+ """Draws samples from an optionally truncated log-logistic distribution."""
+ min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
+ max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
+ min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
+ max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
+ u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
+ return u.logit().mul(scale).add(loc).exp().to(dtype)
+
+
+def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
+ """Draws samples from an log-uniform distribution."""
+ min_value = math.log(min_value)
+ max_value = math.log(max_value)
+ return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
+
+
+def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
+ """Draws samples from a truncated v-diffusion training timestep distribution."""
+ min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
+ max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
+ u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
+ return torch.tan(u * math.pi / 2) * sigma_data
+
+
+def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
+ """Draws samples from a split lognormal distribution."""
+ n = torch.randn(shape, device=device, dtype=dtype).abs()
+ u = torch.rand(shape, device=device, dtype=dtype)
+ n_left = n * -scale_1 + loc
+ n_right = n * scale_2 + loc
+ ratio = scale_1 / (scale_1 + scale_2)
+ return torch.where(u < ratio, n_left, n_right).exp()
+
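These `rand_*` helpers are the sigma/timestep samplers used to pick noise levels during training; a small comparison sketch (shapes and parameters are illustrative only):

```python
import k_diffusion as K

shape = [4]
# Lognormal sigmas, roughly centered around exp(loc).
s1 = K.utils.rand_log_normal(shape, loc=-1.2, scale=1.2)
# Log-logistic sigmas truncated to [1e-2, 80].
s2 = K.utils.rand_log_logistic(shape, loc=-1.2, scale=1.2,
                               min_value=1e-2, max_value=80.)
# Truncated v-diffusion timestep distribution with sigma_data=0.5.
s3 = K.utils.rand_v_diffusion(shape, sigma_data=0.5, min_value=1e-2, max_value=80.)
print(s1, s2, s3)
```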
+
+class FolderOfImages(data.Dataset):
+ """Recursively finds all images in a directory. It does not support
+ classes/targets."""
+
+ IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
+
+ def __init__(self, root, transform=None):
+ super().__init__()
+ self.root = Path(root)
+ self.transform = nn.Identity() if transform is None else transform
+ self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
+
+ def __repr__(self):
+ return f'FolderOfImages(root="{self.root}", len: {len(self)})'
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, key):
+ path = self.paths[key]
+ with open(path, 'rb') as f:
+ image = Image.open(f).convert('RGB')
+ image = self.transform(image)
+ return image,
+
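Note that `__getitem__` returns a 1-tuple, so batches come back as 1-tuples as well. A minimal loading sketch; the directory path and transform are placeholders:

```python
from torch.utils import data
from torchvision import transforms

import k_diffusion as K

tf = transforms.Compose([transforms.Resize(256),
                         transforms.CenterCrop(256),
                         transforms.ToTensor()])
train_set = K.utils.FolderOfImages('/path/to/images', transform=tf)  # placeholder path
train_dl = data.DataLoader(train_set, batch_size=16, shuffle=True)

for (reals,) in train_dl:   # each batch is a 1-tuple of stacked image tensors
    print(reals.shape)
    break
```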
+
+class CSVLogger:
+ def __init__(self, filename, columns):
+ self.filename = Path(filename)
+ self.columns = columns
+ if self.filename.exists():
+ self.file = open(self.filename, 'a')
+ else:
+ self.file = open(self.filename, 'w')
+ self.write(*self.columns)
+
+ def write(self, *args):
+ print(*args, sep=',', file=self.file, flush=True)
+
+
+@contextmanager
+def tf32_mode(cudnn=None, matmul=None):
+ """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
+ cudnn_old = torch.backends.cudnn.allow_tf32
+ matmul_old = torch.backends.cuda.matmul.allow_tf32
+ try:
+ if cudnn is not None:
+ torch.backends.cudnn.allow_tf32 = cudnn
+ if matmul is not None:
+ torch.backends.cuda.matmul.allow_tf32 = matmul
+ yield
+ finally:
+ if cudnn is not None:
+ torch.backends.cudnn.allow_tf32 = cudnn_old
+ if matmul is not None:
+ torch.backends.cuda.matmul.allow_tf32 = matmul_old
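A short usage sketch for `tf32_mode`; passing `None` for either flag leaves that backend setting untouched, and the previous values are restored on exit:

```python
import torch

import k_diffusion as K

device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.randn(1024, 1024, device=device)

# Temporarily allow TF32 matmuls (only relevant on Ampere or newer GPUs).
with K.utils.tf32_mode(matmul=True):
    b = a @ a
```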
diff --git a/repositories/k-diffusion/make_grid.py b/repositories/k-diffusion/make_grid.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6616843cac1a69fdb94df804822cf07b533543
--- /dev/null
+++ b/repositories/k-diffusion/make_grid.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+"""Assembles images into a grid."""
+
+import argparse
+import math
+import sys
+
+from PIL import Image
+
+
+def main():
+ p = argparse.ArgumentParser(description=__doc__)
+ p.add_argument('images', type=str, nargs='+', metavar='image',
+ help='the input images')
+ p.add_argument('--output', '-o', type=str, default='out.png',
+ help='the output image')
+ p.add_argument('--nrow', type=int,
+ help='the number of images per row')
+ args = p.parse_args()
+
+ images = [Image.open(image) for image in args.images]
+ mode = images[0].mode
+ size = images[0].size
+ for image, name in zip(images, args.images):
+ if image.mode != mode:
+ print(f'Error: Image {name} had mode {image.mode}, expected {mode}', file=sys.stderr)
+ sys.exit(1)
+ if image.size != size:
+ print(f'Error: Image {name} had size {image.size}, expected {size}', file=sys.stderr)
+ sys.exit(1)
+
+ n = len(images)
+ x = args.nrow if args.nrow else math.ceil(n**0.5)
+ y = math.ceil(n / x)
+
+ output = Image.new(mode, (size[0] * x, size[1] * y))
+ for i, image in enumerate(images):
+ cur_x, cur_y = i % x, i // x
+ output.paste(image, (size[0] * cur_x, size[1] * cur_y))
+
+ output.save(args.output)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/repositories/k-diffusion/pyproject.toml b/repositories/k-diffusion/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..fed528d4a7a148fd0bf0b0198a6461f8c91b87e9
--- /dev/null
+++ b/repositories/k-diffusion/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
diff --git a/repositories/k-diffusion/requirements.txt b/repositories/k-diffusion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7d4ea409e83f37e7509ea61dad63492617864650
--- /dev/null
+++ b/repositories/k-diffusion/requirements.txt
@@ -0,0 +1,15 @@
+accelerate
+clean-fid
+clip-anytorch
+einops
+jsonmerge
+kornia
+Pillow
+resize-right
+scikit-image
+scipy
+torch
+torchdiffeq
+torchvision
+tqdm
+wandb
diff --git a/repositories/k-diffusion/sample.py b/repositories/k-diffusion/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e0dc3c9ca055f7de73b7df7aa2841025187c18
--- /dev/null
+++ b/repositories/k-diffusion/sample.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+"""Samples from k-diffusion models."""
+
+import argparse
+import math
+
+import accelerate
+import torch
+from tqdm import trange, tqdm
+
+import k_diffusion as K
+
+
+def main():
+ p = argparse.ArgumentParser(description=__doc__,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ p.add_argument('--batch-size', type=int, default=64,
+ help='the batch size')
+ p.add_argument('--checkpoint', type=str, required=True,
+ help='the checkpoint to use')
+ p.add_argument('--config', type=str, required=True,
+ help='the model config')
+ p.add_argument('-n', type=int, default=64,
+ help='the number of images to sample')
+ p.add_argument('--prefix', type=str, default='out',
+ help='the output prefix')
+ p.add_argument('--steps', type=int, default=50,
+ help='the number of denoising steps')
+ args = p.parse_args()
+
+ config = K.config.load_config(open(args.config))
+ model_config = config['model']
+ # TODO: allow non-square input sizes
+ assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
+ size = model_config['input_size']
+
+ accelerator = accelerate.Accelerator()
+ device = accelerator.device
+ print('Using device:', device, flush=True)
+
+ inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
+ inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema'])
+ accelerator.print('Parameters:', K.utils.n_params(inner_model))
+ model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])
+
+ sigma_min = model_config['sigma_min']
+ sigma_max = model_config['sigma_max']
+
+ @torch.no_grad()
+ @K.utils.eval_mode(model)
+ def run():
+ if accelerator.is_local_main_process:
+ tqdm.write('Sampling...')
+ sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
+ def sample_fn(n):
+ x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
+ x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process)
+ return x_0
+ x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
+ if accelerator.is_main_process:
+ for i, out in enumerate(x_0):
+ filename = f'{args.prefix}_{i:05}.png'
+ K.utils.to_pil_image(out).save(filename)
+
+ try:
+ run()
+ except KeyboardInterrupt:
+ pass
+
+
+if __name__ == '__main__':
+ main()
diff --git a/repositories/k-diffusion/sample_clip_guided.py b/repositories/k-diffusion/sample_clip_guided.py
new file mode 100644
index 0000000000000000000000000000000000000000..592350196fbbac8479563be5be9e138248d94c86
--- /dev/null
+++ b/repositories/k-diffusion/sample_clip_guided.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+
+"""CLIP guided sampling from k-diffusion models."""
+
+import argparse
+import math
+
+import accelerate
+import clip
+from kornia import augmentation as KA
+from resize_right import resize
+import torch
+from torch.nn import functional as F
+from torchvision import transforms
+from tqdm import trange, tqdm
+
+import k_diffusion as K
+
+
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+
+def make_cond_model_fn(model, cond_fn):
+ def model_fn(x, sigma, **kwargs):
+ with torch.enable_grad():
+ x = x.detach().requires_grad_()
+ denoised = model(x, sigma, **kwargs)
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
+ return cond_denoised
+ return model_fn
+
+
+def make_static_thresh_model_fn(model, value=1.):
+ def model_fn(x, sigma, **kwargs):
+ return model(x, sigma, **kwargs).clamp(-value, value)
+ return model_fn
+
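A note on `make_cond_model_fn`: for a denoiser D(x; σ), Tweedie's formula gives the score as (D(x; σ) − x)/σ², so adding σ² times the guidance gradient to the denoised prediction shifts the model's score by the gradient of the CLIP loss. A sketch of the identity the helper relies on:

```latex
\hat{D}(x;\sigma) = D(x;\sigma) + \sigma^{2}\, g(x;\sigma)
\quad\Longrightarrow\quad
\frac{\hat{D}(x;\sigma) - x}{\sigma^{2}}
  = \underbrace{\tfrac{D(x;\sigma) - x}{\sigma^{2}}}_{\approx\,\nabla_{x}\log p(x;\sigma)} + g(x;\sigma),
```

where g is the gradient returned by `cond_fn`.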
+
+def main():
+ p = argparse.ArgumentParser(description=__doc__,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ p.add_argument('prompt', type=str,
+ help='the prompt to use')
+ p.add_argument('--batch-size', type=int, default=16,
+ help='the batch size')
+ p.add_argument('--checkpoint', type=str, required=True,
+ help='the checkpoint to use')
+ p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500.,
+ help='the CLIP guidance scale')
+ p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(),
+ help='the CLIP model to use')
+ p.add_argument('--config', type=str, required=True,
+ help='the model config')
+ p.add_argument('-n', type=int, default=64,
+ help='the number of images to sample')
+ p.add_argument('--prefix', type=str, default='out',
+ help='the output prefix')
+ p.add_argument('--steps', type=int, default=100,
+ help='the number of denoising steps')
+ args = p.parse_args()
+
+ config = K.config.load_config(open(args.config))
+ model_config = config['model']
+ # TODO: allow non-square input sizes
+ assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
+ size = model_config['input_size']
+
+ accelerator = accelerate.Accelerator()
+ device = accelerator.device
+ print('Using device:', device, flush=True)
+
+ inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
+ inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema'])
+ accelerator.print('Parameters:', K.utils.n_params(inner_model))
+ model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])
+
+ sigma_min = model_config['sigma_min']
+ sigma_max = model_config['sigma_max']
+
+ clip_model = clip.load(args.clip_model, device=device)[0].eval().requires_grad_(False)
+ clip_normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711))
+ clip_size = (clip_model.visual.input_resolution, clip_model.visual.input_resolution)
+ aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border')
+
+ def get_image_embed(x):
+ if x.shape[2:4] != clip_size:
+ x = resize(x, out_shape=clip_size, pad_mode='reflect')
+ x = clip_normalize(x)
+ x = clip_model.encode_image(x).float()
+ return F.normalize(x)
+
+ target_embed = F.normalize(clip_model.encode_text(clip.tokenize(args.prompt, truncate=True).to(device)).float())
+
+ def cond_fn(x, t, denoised):
+ image_embed = get_image_embed(aug(denoised.add(1).div(2)))
+ loss = spherical_dist_loss(image_embed, target_embed).sum() * args.clip_guidance_scale
+ grad = -torch.autograd.grad(loss, x)[0]
+ return grad
+
+ model_fn = make_cond_model_fn(model, cond_fn)
+ model_fn = make_static_thresh_model_fn(model_fn)
+
+ @torch.no_grad()
+ @K.utils.eval_mode(model)
+ def run():
+ if accelerator.is_local_main_process:
+ tqdm.write('Sampling...')
+ sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)
+ def sample_fn(n):
+ x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0]
+ x_0 = K.sampling.sample_dpmpp_2s_ancestral(model_fn, x, sigmas, eta=1., disable=not accelerator.is_local_main_process)
+ return x_0
+ x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
+ if accelerator.is_main_process:
+ for i, out in enumerate(x_0):
+ filename = f'{args.prefix}_{i:05}.png'
+ K.utils.to_pil_image(out).save(filename)
+
+ try:
+ run()
+ except KeyboardInterrupt:
+ pass
+
+
+if __name__ == '__main__':
+ main()
diff --git a/repositories/k-diffusion/setup.cfg b/repositories/k-diffusion/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..e1ac99d3112d104b6167a1a5de6712a431e98703
--- /dev/null
+++ b/repositories/k-diffusion/setup.cfg
@@ -0,0 +1,30 @@
+[metadata]
+name = k-diffusion
+version = 0.0.12
+author = Katherine Crowson
+author_email = crowsonkb@gmail.com
+url = https://github.com/crowsonkb/k-diffusion
+description = Karras et al. (2022) diffusion models for PyTorch
+long_description = file: README.md
+long_description_content_type = text/markdown
+license = MIT
+
+[options]
+packages = find:
+install_requires =
+ accelerate
+ clean-fid
+ clip-anytorch
+ einops
+ jsonmerge
+ kornia
+ Pillow
+ resize-right
+ scikit-image
+ scipy
+ torch
+ torchdiffeq
+ torchsde
+ torchvision
+ tqdm
+ wandb
diff --git a/repositories/k-diffusion/setup.py b/repositories/k-diffusion/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae4555937eb30e6632281a2326726826a41fe88
--- /dev/null
+++ b/repositories/k-diffusion/setup.py
@@ -0,0 +1,5 @@
+from setuptools import setup
+
+
+if __name__ == '__main__':
+ setup()
diff --git a/repositories/k-diffusion/train.py b/repositories/k-diffusion/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbfbeb9c6f523a9b8ce03cb353f81146b6f95051
--- /dev/null
+++ b/repositories/k-diffusion/train.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+
+"""Trains Karras et al. (2022) diffusion models."""
+
+import argparse
+from copy import deepcopy
+from functools import partial
+import math
+import json
+from pathlib import Path
+
+import accelerate
+import torch
+from torch import nn, optim
+from torch import multiprocessing as mp
+from torch.utils import data
+from torchvision import datasets, transforms, utils
+from tqdm.auto import trange, tqdm
+
+import k_diffusion as K
+
+
+def main():
+ p = argparse.ArgumentParser(description=__doc__,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ p.add_argument('--batch-size', type=int, default=64,
+ help='the batch size')
+ p.add_argument('--config', type=str, required=True,
+ help='the configuration file')
+ p.add_argument('--demo-every', type=int, default=500,
+ help='save a demo grid every this many steps')
+ p.add_argument('--evaluate-every', type=int, default=10000,
+ help='evaluate every this many steps')
+ p.add_argument('--evaluate-n', type=int, default=2000,
+ help='the number of samples to draw to evaluate')
+ p.add_argument('--gns', action='store_true',
+ help='measure the gradient noise scale (DDP only)')
+ p.add_argument('--grad-accum-steps', type=int, default=1,
+ help='the number of gradient accumulation steps')
+ p.add_argument('--grow', type=str,
+ help='the checkpoint to grow from')
+ p.add_argument('--grow-config', type=str,
+ help='the configuration file of the model to grow from')
+ p.add_argument('--lr', type=float,
+ help='the learning rate')
+ p.add_argument('--name', type=str, default='model',
+ help='the name of the run')
+ p.add_argument('--num-workers', type=int, default=8,
+ help='the number of data loader workers')
+ p.add_argument('--resume', type=str,
+ help='the checkpoint to resume from')
+ p.add_argument('--sample-n', type=int, default=64,
+ help='the number of images to sample for demo grids')
+ p.add_argument('--save-every', type=int, default=10000,
+ help='save every this many steps')
+ p.add_argument('--seed', type=int,
+ help='the random seed')
+ p.add_argument('--start-method', type=str, default='spawn',
+ choices=['fork', 'forkserver', 'spawn'],
+ help='the multiprocessing start method')
+ p.add_argument('--wandb-entity', type=str,
+ help='the wandb entity name')
+ p.add_argument('--wandb-group', type=str,
+ help='the wandb group name')
+ p.add_argument('--wandb-project', type=str,
+ help='the wandb project name (specify this to enable wandb)')
+ p.add_argument('--wandb-save-model', action='store_true',
+ help='save model to wandb')
+ args = p.parse_args()
+
+ mp.set_start_method(args.start_method)
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ config = K.config.load_config(open(args.config))
+ model_config = config['model']
+ dataset_config = config['dataset']
+ opt_config = config['optimizer']
+ sched_config = config['lr_sched']
+ ema_sched_config = config['ema_sched']
+
+ # TODO: allow non-square input sizes
+ assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
+ size = model_config['input_size']
+
+ ddp_kwargs = accelerate.DistributedDataParallelKwargs(find_unused_parameters=model_config['skip_stages'] > 0)
+ accelerator = accelerate.Accelerator(kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps=args.grad_accum_steps)
+ device = accelerator.device
+ print(f'Process {accelerator.process_index} using device: {device}', flush=True)
+
+ if args.seed is not None:
+ seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed))
+ torch.manual_seed(seeds[accelerator.process_index])
+
+ inner_model = K.config.make_model(config)
+ if accelerator.is_main_process:
+ print('Parameters:', K.utils.n_params(inner_model))
+
+ # If logging to wandb, initialize the run
+ use_wandb = accelerator.is_main_process and args.wandb_project
+ if use_wandb:
+ import wandb
+ log_config = vars(args)
+ log_config['config'] = config
+ log_config['parameters'] = K.utils.n_params(inner_model)
+ wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=log_config, save_code=True)
+
+ if opt_config['type'] == 'adamw':
+ opt = optim.AdamW(inner_model.parameters(),
+ lr=opt_config['lr'] if args.lr is None else args.lr,
+ betas=tuple(opt_config['betas']),
+ eps=opt_config['eps'],
+ weight_decay=opt_config['weight_decay'])
+ elif opt_config['type'] == 'sgd':
+ opt = optim.SGD(inner_model.parameters(),
+ lr=opt_config['lr'] if args.lr is None else args.lr,
+ momentum=opt_config.get('momentum', 0.),
+ nesterov=opt_config.get('nesterov', False),
+ weight_decay=opt_config.get('weight_decay', 0.))
+ else:
+ raise ValueError('Invalid optimizer type')
+
+ if sched_config['type'] == 'inverse':
+ sched = K.utils.InverseLR(opt,
+ inv_gamma=sched_config['inv_gamma'],
+ power=sched_config['power'],
+ warmup=sched_config['warmup'])
+ elif sched_config['type'] == 'exponential':
+ sched = K.utils.ExponentialLR(opt,
+ num_steps=sched_config['num_steps'],
+ decay=sched_config['decay'],
+ warmup=sched_config['warmup'])
+ else:
+ raise ValueError('Invalid schedule type')
+
+ assert ema_sched_config['type'] == 'inverse'
+ ema_sched = K.utils.EMAWarmup(power=ema_sched_config['power'],
+ max_value=ema_sched_config['max_value'])
+
+ tf = transforms.Compose([
+ transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
+ transforms.CenterCrop(size[0]),
+ K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob']),
+ ])
+
+ if dataset_config['type'] == 'imagefolder':
+ train_set = K.utils.FolderOfImages(dataset_config['location'], transform=tf)
+ elif dataset_config['type'] == 'cifar10':
+ train_set = datasets.CIFAR10(dataset_config['location'], train=True, download=True, transform=tf)
+ elif dataset_config['type'] == 'mnist':
+ train_set = datasets.MNIST(dataset_config['location'], train=True, download=True, transform=tf)
+ elif dataset_config['type'] == 'huggingface':
+ from datasets import load_dataset
+ train_set = load_dataset(dataset_config['location'])
+ train_set.set_transform(partial(K.utils.hf_datasets_augs_helper, transform=tf, image_key=dataset_config['image_key']))
+ train_set = train_set['train']
+ else:
+ raise ValueError('Invalid dataset type')
+
+ if accelerator.is_main_process:
+ try:
+ print('Number of items in dataset:', len(train_set))
+ except TypeError:
+ pass
+
+ image_key = dataset_config.get('image_key', 0)
+
+ train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True,
+ num_workers=args.num_workers, persistent_workers=True)
+
+ if args.grow:
+ if not args.grow_config:
+ raise ValueError('--grow requires --grow-config')
+ ckpt = torch.load(args.grow, map_location='cpu')
+ old_config = K.config.load_config(open(args.grow_config))
+ old_inner_model = K.config.make_model(old_config)
+ old_inner_model.load_state_dict(ckpt['model_ema'])
+ if old_config['model']['skip_stages'] != model_config['skip_stages']:
+ old_inner_model.set_skip_stages(model_config['skip_stages'])
+ if old_config['model']['patch_size'] != model_config['patch_size']:
+ old_inner_model.set_patch_size(model_config['patch_size'])
+ inner_model.load_state_dict(old_inner_model.state_dict())
+ del ckpt, old_inner_model
+
+ inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl)
+ if use_wandb:
+ wandb.watch(inner_model)
+ if args.gns:
+ gns_stats_hook = K.gns.DDPGradientStatsHook(inner_model)
+ gns_stats = K.gns.GradientNoiseScale()
+ else:
+ gns_stats = None
+ sigma_min = model_config['sigma_min']
+ sigma_max = model_config['sigma_max']
+ sample_density = K.config.make_sample_density(model_config)
+
+ model = K.config.make_denoiser_wrapper(config)(inner_model)
+ model_ema = deepcopy(model)
+
+ state_path = Path(f'{args.name}_state.json')
+
+ if state_path.exists() or args.resume:
+ if args.resume:
+ ckpt_path = args.resume
+ if not args.resume:
+ state = json.load(open(state_path))
+ ckpt_path = state['latest_checkpoint']
+ if accelerator.is_main_process:
+ print(f'Resuming from {ckpt_path}...')
+ ckpt = torch.load(ckpt_path, map_location='cpu')
+ accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model'])
+ accelerator.unwrap_model(model_ema.inner_model).load_state_dict(ckpt['model_ema'])
+ opt.load_state_dict(ckpt['opt'])
+ sched.load_state_dict(ckpt['sched'])
+ ema_sched.load_state_dict(ckpt['ema_sched'])
+ epoch = ckpt['epoch'] + 1
+ step = ckpt['step'] + 1
+ if args.gns and ckpt.get('gns_stats', None) is not None:
+ gns_stats.load_state_dict(ckpt['gns_stats'])
+
+ del ckpt
+ else:
+ epoch = 0
+ step = 0
+
+ evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0
+ if evaluate_enabled:
+ extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
+ train_iter = iter(train_dl)
+ if accelerator.is_main_process:
+ print('Computing features for reals...')
+ reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size)
+ if accelerator.is_main_process:
+ metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'fid', 'kid'])
+ del train_iter
+
+ @torch.no_grad()
+ @K.utils.eval_mode(model_ema)
+ def demo():
+ if accelerator.is_main_process:
+ tqdm.write('Sampling...')
+ filename = f'{args.name}_demo_{step:08}.png'
+ n_per_proc = math.ceil(args.sample_n / accelerator.num_processes)
+ x = torch.randn([n_per_proc, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
+ sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
+ x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=not accelerator.is_main_process)
+ x_0 = accelerator.gather(x_0)[:args.sample_n]
+ if accelerator.is_main_process:
+ grid = utils.make_grid(x_0, nrow=math.ceil(args.sample_n ** 0.5), padding=0)
+ K.utils.to_pil_image(grid).save(filename)
+ if use_wandb:
+ wandb.log({'demo_grid': wandb.Image(filename)}, step=step)
+
+ @torch.no_grad()
+ @K.utils.eval_mode(model_ema)
+ def evaluate():
+ if not evaluate_enabled:
+ return
+ if accelerator.is_main_process:
+ tqdm.write('Evaluating...')
+ sigmas = K.sampling.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device)
+ def sample_fn(n):
+ x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
+ x_0 = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, disable=True)
+ return x_0
+ fakes_features = K.evaluation.compute_features(accelerator, sample_fn, extractor, args.evaluate_n, args.batch_size)
+ if accelerator.is_main_process:
+ fid = K.evaluation.fid(fakes_features, reals_features)
+ kid = K.evaluation.kid(fakes_features, reals_features)
+ print(f'FID: {fid.item():g}, KID: {kid.item():g}')
+ if accelerator.is_main_process:
+ metrics_log.write(step, fid.item(), kid.item())
+ if use_wandb:
+ wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step)
+
+ def save():
+ accelerator.wait_for_everyone()
+ filename = f'{args.name}_{step:08}.pth'
+ if accelerator.is_main_process:
+ tqdm.write(f'Saving to {filename}...')
+ obj = {
+ 'model': accelerator.unwrap_model(model.inner_model).state_dict(),
+ 'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(),
+ 'opt': opt.state_dict(),
+ 'sched': sched.state_dict(),
+ 'ema_sched': ema_sched.state_dict(),
+ 'epoch': epoch,
+ 'step': step,
+ 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None,
+ }
+ accelerator.save(obj, filename)
+ if accelerator.is_main_process:
+ state_obj = {'latest_checkpoint': filename}
+ json.dump(state_obj, open(state_path, 'w'))
+ if args.wandb_save_model and use_wandb:
+ wandb.save(filename)
+
+ try:
+ while True:
+ for batch in tqdm(train_dl, disable=not accelerator.is_main_process):
+ with accelerator.accumulate(model):
+ reals, _, aug_cond = batch[image_key]
+ noise = torch.randn_like(reals)
+ sigma = sample_density([reals.shape[0]], device=device)
+ losses = model.loss(reals, noise, sigma, aug_cond=aug_cond)
+ losses_all = accelerator.gather(losses)
+ loss = losses_all.mean()
+ accelerator.backward(losses.mean())
+ if args.gns:
+ sq_norm_small_batch, sq_norm_large_batch = gns_stats_hook.get_stats()
+ gns_stats.update(sq_norm_small_batch, sq_norm_large_batch, reals.shape[0], reals.shape[0] * accelerator.num_processes)
+ opt.step()
+ sched.step()
+ opt.zero_grad()
+ if accelerator.sync_gradients:
+ ema_decay = ema_sched.get_value()
+ K.utils.ema_update(model, model_ema, ema_decay)
+ ema_sched.step()
+
+ if accelerator.is_main_process:
+ if step % 25 == 0:
+ if args.gns:
+ tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}, gns: {gns_stats.get_gns():g}')
+ else:
+ tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}')
+
+ if use_wandb:
+ log_dict = {
+ 'epoch': epoch,
+ 'loss': loss.item(),
+ 'lr': sched.get_last_lr()[0],
+ 'ema_decay': ema_decay,
+ }
+ if args.gns:
+ log_dict['gradient_noise_scale'] = gns_stats.get_gns()
+ wandb.log(log_dict, step=step)
+
+ if step % args.demo_every == 0:
+ demo()
+
+ if evaluate_enabled and step > 0 and step % args.evaluate_every == 0:
+ evaluate()
+
+ if step > 0 and step % args.save_every == 0:
+ save()
+
+ step += 1
+ epoch += 1
+ except KeyboardInterrupt:
+ pass
+
+
+if __name__ == '__main__':
+ main()
diff --git a/repositories/stable-diffusion-stability-ai/LICENSE b/repositories/stable-diffusion-stability-ai/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..58a49c99b2b9151af5e1fee0dbd20307671f47ab
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Stability AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/repositories/stable-diffusion-stability-ai/LICENSE-MODEL b/repositories/stable-diffusion-stability-ai/LICENSE-MODEL
new file mode 100644
index 0000000000000000000000000000000000000000..9684533d88e7d853a55cabf6caa2f1e4a3e6fdc4
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/LICENSE-MODEL
@@ -0,0 +1,84 @@
+Copyright (c) 2022 Stability AI and contributors
+
+CreativeML Open RAIL++-M License
+dated November 24, 2022
+
+Section I: PREAMBLE
+
+Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
+
+Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
+
+In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
+
+Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
+
+This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
+
+NOW THEREFORE, You and Licensor agree as follows:
+
+1. Definitions
+
+- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
+- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
+- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
+- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
+- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
+- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
+- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
+- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
+- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
+- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
+- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
+
+Section II: INTELLECTUAL PROPERTY RIGHTS
+
+Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
+3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
+
+Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
+
+4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
+Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
+You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
+You must cause any modified files to carry prominent notices stating that You changed the files;
+You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
+5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
+6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
+
+Section IV: OTHER PROVISIONS
+
+7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.
+8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
+9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
+10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
+
+END OF TERMS AND CONDITIONS
+
+
+
+
+Attachment A
+
+Use Restrictions
+
+You agree not to use the Model or Derivatives of the Model:
+
+- In any way that violates any applicable national, federal, state, local or international law or regulation;
+- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
+- To generate or disseminate personal identifiable information that can be used to harm an individual;
+- To defame, disparage or otherwise harass others;
+- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
+- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
+- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
+- To provide medical advice and medical results interpretation;
+- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
+
diff --git a/repositories/stable-diffusion-stability-ai/README.md b/repositories/stable-diffusion-stability-ai/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..061c15bd792e935067d3437f9e30b86b122e6e02
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/README.md
@@ -0,0 +1,245 @@
+# Stable Diffusion 2.0
+![t2i](assets/stable-samples/txt2img/768/merged-0006.png)
+![t2i](assets/stable-samples/txt2img/768/merged-0002.png)
+![t2i](assets/stable-samples/txt2img/768/merged-0005.png)
+
+This repository contains [Stable Diffusion](https://github.com/CompVis/stable-diffusion) models trained from scratch and will be continuously updated with
+new checkpoints. The following list provides an overview of all currently available models. More coming soon.
+## News
+**November 2022**
+- New stable diffusion model (_Stable Diffusion 2.0-v_) at 768x768 resolution. Same number of parameters in the U-Net as 1.5, but uses [OpenCLIP-ViT/H](https://github.com/mlfoundations/open_clip) as the text encoder and is trained from scratch. _SD 2.0-v_ is a so-called [v-prediction](https://arxiv.org/abs/2202.00512) model.
+- The above model is finetuned from _SD 2.0-base_, which was trained as a standard noise-prediction model on 512x512 images and is also made available.
+- Added a [x4 upscaling latent text-guided diffusion model](#image-upscaling-with-stable-diffusion).
+- New [depth-guided stable diffusion model](#depth-conditional-stable-diffusion), finetuned from _SD 2.0-base_. The model is conditioned on monocular depth estimates inferred via [MiDaS](https://github.com/isl-org/MiDaS) and can be used for structure-preserving img2img and shape-conditional synthesis.
+
+ ![d2i](assets/stable-samples/depth2img/depth2img01.png)
+- A [text-guided inpainting model](#image-inpainting-with-stable-diffusion), finetuned from SD _2.0-base_.
+
+We follow the [original repository](https://github.com/CompVis/stable-diffusion) and provide basic inference scripts to sample from the models.
+
+________________
+*The original Stable Diffusion model was created in a collaboration with [CompVis](https://arxiv.org/abs/2202.00512) and [RunwayML](https://runwayml.com/) and builds upon the work:*
+
+[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)
+[Robin Rombach](https://github.com/rromb)\*,
+[Andreas Blattmann](https://github.com/ablattmann)\*,
+[Dominik Lorenz](https://github.com/qp-qp),
+[Patrick Esser](https://github.com/pesser),
+[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+_[CVPR '22 Oral](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html) |
+[GitHub](https://github.com/CompVis/latent-diffusion) | [arXiv](https://arxiv.org/abs/2112.10752) | [Project page](https://ommer-lab.com/research/latent-diffusion-models/)_
+
+and [many others](#shout-outs).
+
+Stable Diffusion is a latent text-to-image diffusion model.
+________________________________
+
+## Requirements
+
+You can update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
+
+```
+conda install pytorch==1.12.1 torchvision==0.13.1 -c pytorch
+pip install transformers==4.19.2 diffusers invisible-watermark
+pip install -e .
+```
+#### xformers efficient attention
+For more efficiency and speed on GPUs,
+we highly recommend installing the [xformers](https://github.com/facebookresearch/xformers)
+library.
+
+Tested on A100 with CUDA 11.4.
+Installation needs a somewhat recent version of nvcc and gcc/g++, obtain those, e.g., via
+```commandline
+export CUDA_HOME=/usr/local/cuda-11.4
+conda install -c nvidia/label/cuda-11.4.0 cuda-nvcc
+conda install -c conda-forge gcc
+conda install -c conda-forge gxx_linux-64=9.5.0
+```
+
+Then, run the following (compiling takes up to 30 min).
+
+```commandline
+cd ..
+git clone https://github.com/facebookresearch/xformers.git
+cd xformers
+git submodule update --init --recursive
+pip install -r requirements.txt
+pip install -e .
+cd ../stablediffusion
+```
+Upon successful installation, the code will automatically default to [memory efficient attention](https://github.com/facebookresearch/xformers)
+for the self- and cross-attention layers in the U-Net and autoencoder.
+
+## General Disclaimer
+Stable Diffusion models are general text-to-image diffusion models and therefore mirror biases and (mis-)conceptions that are present
+in their training data. Although efforts were made to reduce the inclusion of explicit pornographic material, **we do not recommend using the provided weights for services or products without additional safety mechanisms and considerations.
+The weights are research artifacts and should be treated as such.**
+Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/stabilityai/stable-diffusion-2).
+The weights are available via [the StabilityAI organization at Hugging Face](https://huggingface.co/StabilityAI) under the [CreativeML Open RAIL++-M License](LICENSE-MODEL).
+
+
+
+## Stable Diffusion v2.0
+
+Stable Diffusion v2.0 refers to a specific configuration of the model
+architecture that uses a downsampling-factor 8 autoencoder with an 865M UNet
+and OpenCLIP ViT-H/14 text encoder for the diffusion model. The _SD 2.0-v_ model produces 768x768 px outputs.
+
+Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
+5.0, 6.0, 7.0, 8.0) and 50 DDIM sampling steps show the relative improvements of the checkpoints:
+
+![sd evaluation results](assets/model-variants.jpg)
+
+
+
+### Text-to-Image
+![txt2img-stable2](assets/stable-samples/txt2img/merged-0003.png)
+![txt2img-stable2](assets/stable-samples/txt2img/merged-0001.png)
+
+Stable Diffusion 2.0 is a latent diffusion model conditioned on the penultimate text embeddings of a CLIP ViT-H/14 text encoder.
+We provide a [reference script for sampling](#reference-sampling-script).
+#### Reference Sampling Script
+
+This script incorporates an [invisible watermarking](https://github.com/ShieldMnt/invisible-watermark) of the outputs, to help viewers [identify the images as machine-generated](scripts/tests/test_watermark.py).
+We provide the configs for the _SD2.0-v_ (768px) and _SD2.0-base_ (512px) model.
+
+First, download the weights for [_SD2.0-v_](https://huggingface.co/stabilityai/stable-diffusion-2) and [_SD2.0-base_](https://huggingface.co/stabilityai/stable-diffusion-2-base).
+
+To sample from the _SD2.0-v_ model, run the following:
+
+```
+python scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt --config configs/stable-diffusion/v2-inference-v.yaml --H 768 --W 768
+```
+or try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/stabilityai/stable-diffusion).
+
+To sample from the base model, use
+```
+python scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt --config
+```
+
+By default, this uses the [DDIM sampler](https://arxiv.org/abs/2010.02502), and renders images of size 768x768 (which it was trained on) in 50 steps.
+Empirically, the v-models can be sampled with higher guidance scales.
+
+Note: The inference config for all model versions is designed to be used with EMA-only checkpoints.
+For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
+non-EMA to EMA weights.
+
+### Image Modification with Stable Diffusion
+
+![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png)
+#### Depth-Conditional Stable Diffusion
+
+To augment the well-established [img2img](https://github.com/CompVis/stable-diffusion#image-modification-with-stable-diffusion) functionality of Stable Diffusion, we provide a _shape-preserving_ stable diffusion model.
+
+
+Note that the original method for image modification introduces significant semantic changes w.r.t. the initial image.
+If that is not desired, download our [depth-conditional stable diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-depth) model and the `dpt_hybrid` MiDaS [model weights](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), place the latter in a folder `midas_models` and sample via
+```
+python scripts/gradio/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-checkpoint>
+```
+
+or
+
+```
+streamlit run scripts/streamlit/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml <path-to-checkpoint>
+```
+
+This method can be used on the samples of the base model itself.
+For example, take [this sample](assets/stable-samples/depth2img/old_man.png) generated by an anonymous discord user.
+Using the [gradio](https://gradio.app) or [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input,
+and the diffusion model is then conditioned on the (relative) depth output.
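+
+A hedged sketch of that first step is shown below, using the `AddMiDaS` helper from `ldm/data/util.py` and the `MiDaSInference` wrapper referenced in `v2-midas-inference.yaml` (both included later in this diff). The exact output shape and normalization of the MiDaS wrapper are assumptions, and the diffusion-side conditioning is left out.
+
+```
+import numpy as np
+import torch
+from PIL import Image
+
+from ldm.data.util import AddMiDaS
+from ldm.modules.midas.api import MiDaSInference
+
+image = Image.open("assets/stable-samples/depth2img/old_man.png").convert("RGB")
+x = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0   # HWC tensor in [-1, 1]
+
+prep = AddMiDaS(model_type="dpt_hybrid")                       # resizes/normalizes for MiDaS
+midas_in = prep({"jpg": x})["midas_in"]                        # numpy CHW in MiDaS input space
+
+depth_model = MiDaSInference(model_type="dpt_hybrid").eval()
+with torch.no_grad():
+    depth = depth_model(torch.from_numpy(midas_in)[None])      # relative (inverse) depth map
+```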
+
+
+*depth2image*
+
+
+
+This model is particularly useful for a photorealistic style; see the [examples](assets/stable-samples/depth2img).
+For a maximum strength of 1.0, the model removes all pixel-based information and only relies on the text prompt and the inferred monocular depth estimate.
+
+![depth2img-stable3](assets/stable-samples/depth2img/merged-0005.png)
+
+#### Classic Img2Img
+
+For running the "classic" img2img, use
+```
+python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img <path-to-img.jpg> --strength 0.8 --ckpt <path/to/model.ckpt>
+```
+and adapt the checkpoint and config paths accordingly.
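+
+The `--strength` flag is the knob described above: it decides how far along the DDIM schedule the encoded init image is noised before being denoised again. Below is a hedged sketch of that loop using the `DDIMSampler` included in this diff (`ldm/models/diffusion/ddim.py`); the model, init latent and conditioning tensors are placeholders obtained elsewhere, and the helper name `img2img_step` is made up for illustration.
+
+```
+import torch
+from ldm.models.diffusion.ddim import DDIMSampler
+
+def img2img_step(model, init_latent, cond, uncond, strength=0.8, ddim_steps=50, scale=9.0):
+    sampler = DDIMSampler(model)
+    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
+
+    # strength in [0, 1) picks how many of the ddim_steps are used to noise the
+    # init latent; near 1.0 essentially no pixel information survives.
+    t_enc = int(strength * ddim_steps)
+    t = torch.full((init_latent.shape[0],), t_enc, device=init_latent.device, dtype=torch.long)
+    z_enc = sampler.stochastic_encode(init_latent, t)
+
+    # denoise back for the remaining t_enc steps under the text conditioning
+    return sampler.decode(z_enc, cond, t_enc,
+                          unconditional_guidance_scale=scale,
+                          unconditional_conditioning=uncond)
+```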
+
+### Image Upscaling with Stable Diffusion
+![upscaling-x4](assets/stable-samples/upscaling/merged-dog.png)
+After [downloading the weights](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler), run
+```
+python scripts/gradio/superresolution.py configs/stable-diffusion/x4-upscaling.yaml <path-to-checkpoint>
+```
+
+or
+
+```
+streamlit run scripts/streamlit/superresolution.py -- configs/stable-diffusion/x4-upscaling.yaml <path-to-checkpoint>
+```
+
+for a Gradio or Streamlit demo of the text-guided x4 superresolution model.
+This model can be used both on real inputs and on synthesized examples. For the latter, we recommend setting a higher
+`noise_level`, e.g. `noise_level=100`.
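+
+For intuition: `noise_level` indexes the image-space noise schedule declared in `low_scale_config` of `x4-upscaling.yaml` (linear_start 0.0001, linear_end 0.02, max_noise_level 350); the low-resolution conditioning image is noised at that level before being concatenated, so a higher value wipes out more of a synthetic image's artifacts. The standalone sketch below only illustrates the effect; it is not the library call, and the 1000-step length and sqrt-linear spacing of the schedule are assumptions.
+
+```
+import torch
+
+def noise_low_res(lr_image, noise_level=100, timesteps=1000,
+                  linear_start=1e-4, linear_end=2e-2):
+    # assumed beta schedule; cumulative product of (1 - beta) gives alpha_bar
+    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
+    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+    a = alphas_cumprod[noise_level]
+    # standard forward-diffusion noising of the conditioning image
+    return a.sqrt() * lr_image + (1.0 - a).sqrt() * torch.randn_like(lr_image)
+```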
+
+### Image Inpainting with Stable Diffusion
+
+![inpainting-stable2](assets/stable-inpainting/merged-leopards.png)
+
+[Download the SD 2.0-inpainting checkpoint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) and run
+
+```
+python scripts/gradio/inpainting.py configs/stable-diffusion/v2-inpainting-inference.yaml <path-to-checkpoint>
+```
+
+or
+
+```
+streamlit run scripts/streamlit/inpainting.py -- configs/stable-diffusion/v2-inpainting-inference.yaml <path-to-checkpoint>
+```
+
+for a Gradio or Streamlit demo of the inpainting model.
+This script adds invisible watermarking to the demo in the [RunwayML](https://github.com/runwayml/stable-diffusion/blob/main/scripts/inpaint_st.py) repository, but both should work interchangeably with the checkpoints/configs.
+
+
+
+## Shout-Outs
+- Thanks to [Hugging Face](https://huggingface.co/) and in particular [Apolinário](https://github.com/apolinario) for support with our model releases!
+- Stable Diffusion would not be possible without [LAION](https://laion.ai/) and their efforts to create open, large-scale datasets.
+- The [DeepFloyd team](https://twitter.com/deepfloydai) at Stability AI, for creating the subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset used to train the model.
+- Stable Diffusion 2.0 uses [OpenCLIP](https://laion.ai/blog/large-openclip/), trained by [Romain Beaumont](https://github.com/rom1504).
+- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
+and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
+Thanks for open-sourcing!
+- [CompVis](https://github.com/CompVis/stable-diffusion) for the initial Stable Diffusion release.
+- [Patrick](https://github.com/pesser)'s [implementation](https://github.com/runwayml/stable-diffusion/blob/main/scripts/inpaint_st.py) of the streamlit demo for inpainting.
+- `img2img` is an application of [SDEdit](https://arxiv.org/abs/2108.01073) by [Chenlin Meng](https://cs.stanford.edu/~chenlin/) from the [Stanford AI Lab](https://cs.stanford.edu/~ermon/website/).
+- [Kat's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler, and [more](https://github.com/crowsonkb/k-diffusion).
+- [DPMSolver](https://arxiv.org/abs/2206.00927) [integration](https://github.com/CompVis/stable-diffusion/pull/440) by [Cheng Lu](https://github.com/LuChengTHU).
+- Facebook's [xformers](https://github.com/facebookresearch/xformers) for efficient attention computation.
+- [MiDaS](https://github.com/isl-org/MiDaS) for monocular depth estimation.
+
+
+## License
+
+The code in this repository is released under the MIT License.
+
+The weights are available via [the StabilityAI organization at Hugging Face](https://huggingface.co/StabilityAI), and released under the [CreativeML Open RAIL++-M License](LICENSE-MODEL).
+
+## BibTeX
+
+```
+@misc{rombach2021highresolution,
+ title={High-Resolution Image Synthesis with Latent Diffusion Models},
+ author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
+ year={2021},
+ eprint={2112.10752},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+
diff --git a/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inference-v.yaml b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inference-v.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8ec8dfbfefe94ae8522c93017668fea78d580acf
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inference-v.yaml
@@ -0,0 +1,68 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False # we set this to false because this is an inference only config
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ #attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
diff --git a/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inference.yaml b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..152c4f3c2b36c3b246a9cb10eb8166134b0d2e1c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inference.yaml
@@ -0,0 +1,67 @@
+model:
+ base_learning_rate: 1.0e-4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False # we set this to false because this is an inference only config
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ use_fp16: True
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ #attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
diff --git a/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inpainting-inference.yaml b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inpainting-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..32a9471d71b828c51bcbbabfe34c5f6c8282c803
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-inpainting-inference.yaml
@@ -0,0 +1,158 @@
+model:
+ base_learning_rate: 5.0e-05
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: hybrid
+ scale_factor: 0.18215
+ monitor: val/loss_simple_ema
+ finetune_keys: null
+ use_ema: False
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ image_size: 32 # unused
+ in_channels: 9
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ #attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+
+data:
+ target: ldm.data.laion.WebDataModuleFromConfig
+ params:
+ tar_base: null # for concat as in LAION-A
+ p_unsafe_threshold: 0.1
+ filter_word_list: "data/filters.yaml"
+ max_pwatermark: 0.45
+ batch_size: 8
+ num_workers: 6
+ multinode: True
+ min_size: 512
+ train:
+ shards:
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
+ - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
+ shuffle: 10000
+ image_key: jpg
+ image_transforms:
+ - target: torchvision.transforms.Resize
+ params:
+ size: 512
+ interpolation: 3
+ - target: torchvision.transforms.RandomCrop
+ params:
+ size: 512
+ postprocess:
+ target: ldm.data.laion.AddMask
+ params:
+ mode: "512train-large"
+ p_drop: 0.25
+ # NOTE use enough shards to avoid empty validation loops in workers
+ validation:
+ shards:
+ - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
+ shuffle: 0
+ image_key: jpg
+ image_transforms:
+ - target: torchvision.transforms.Resize
+ params:
+ size: 512
+ interpolation: 3
+ - target: torchvision.transforms.CenterCrop
+ params:
+ size: 512
+ postprocess:
+ target: ldm.data.laion.AddMask
+ params:
+ mode: "512train-large"
+ p_drop: 0.25
+
+lightning:
+ find_unused_parameters: True
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 5000
+
+ callbacks:
+ metrics_over_trainsteps_checkpoint:
+ params:
+ every_n_train_steps: 10000
+
+ image_logger:
+ target: main.ImageLogger
+ params:
+ enable_autocast: False
+ disabled: False
+ batch_frequency: 1000
+ max_images: 4
+ increase_log_steps: False
+ log_first_step: False
+ log_images_kwargs:
+ use_ema_scope: False
+ inpaint: False
+ plot_progressive_rows: False
+ plot_diffusion_rows: False
+ N: 4
+ unconditional_guidance_scale: 5.0
+ unconditional_guidance_label: [""]
+ ddim_steps: 50 # todo check these out for depth2img,
+ ddim_eta: 0.0 # todo check these out for depth2img,
+
+ trainer:
+ benchmark: True
+ val_check_interval: 5000000
+ num_sanity_val_steps: 0
+ accumulate_grad_batches: 1
diff --git a/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-midas-inference.yaml b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-midas-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f20c30f618b81091e31c2c4cf15325fa38638af4
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/v2-midas-inference.yaml
@@ -0,0 +1,74 @@
+model:
+ base_learning_rate: 5.0e-07
+ target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: hybrid
+ scale_factor: 0.18215
+ monitor: val/loss_simple_ema
+ finetune_keys: null
+ use_ema: False
+
+ depth_stage_config:
+ target: ldm.modules.midas.api.MiDaSInference
+ params:
+ model_type: "dpt_hybrid"
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ image_size: 32 # unused
+ in_channels: 5
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ #attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+
diff --git a/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/x4-upscaling.yaml b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/x4-upscaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2db0964af699f86d1891c761710a2d53f59b842c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/configs/stable-diffusion/x4-upscaling.yaml
@@ -0,0 +1,76 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
+ params:
+ parameterization: "v"
+ low_scale_key: "lr"
+ linear_start: 0.0001
+ linear_end: 0.02
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 128
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: "hybrid-adm"
+ monitor: val/loss_simple_ema
+ scale_factor: 0.08333
+ use_ema: False
+
+ low_scale_config:
+ target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
+ params:
+ noise_schedule_config: # image space
+ linear_start: 0.0001
+ linear_end: 0.02
+ max_noise_level: 350
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
+ image_size: 128
+ in_channels: 7
+ out_channels: 4
+ model_channels: 256
+ attention_resolutions: [ 2,4,8]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 2, 4]
+ disable_self_attentions: [True, True, True, False]
+ disable_middle_self_attn: False
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+ use_linear_in_transformer: True
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ ddconfig:
+ # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
diff --git a/repositories/stable-diffusion-stability-ai/environment.yaml b/repositories/stable-diffusion-stability-ai/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4687b309b60ae2d6040fcb3f90a380cf6fb11b21
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/environment.yaml
@@ -0,0 +1,29 @@
+name: ldm
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.8.5
+ - pip=20.3
+ - cudatoolkit=11.3
+ - pytorch=1.12.1
+ - torchvision=0.13.1
+ - numpy=1.23.1
+ - pip:
+ - albumentations==1.3.0
+ - opencv-python==4.6.0.66
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.4.2
+ - omegaconf==2.1.1
+ - test-tube>=0.7.5
+ - streamlit==1.12.1
+ - einops==0.3.0
+ - transformers==4.19.2
+ - webdataset==0.2.5
+ - kornia==0.6
+ - open_clip_torch==2.0.2
+ - invisible-watermark>=0.1.5
+ - streamlit-drawable-canvas==0.8.0
+ - torchmetrics==0.6.0
+ - -e .
diff --git a/repositories/stable-diffusion-stability-ai/ldm/data/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/data/util.py b/repositories/stable-diffusion-stability-ai/ldm/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+
+from ldm.modules.midas.api import load_midas_transform
+
+
+class AddMiDaS(object):
+ def __init__(self, model_type):
+ super().__init__()
+ self.transform = load_midas_transform(model_type)
+
+ def pt2np(self, x):
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
+ return x
+
+ def np2pt(self, x):
+ x = torch.from_numpy(x) * 2 - 1.
+ return x
+
+ def __call__(self, sample):
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
+ x = self.pt2np(sample['jpg'])
+ x = self.transform({"image": x})["image"]
+ sample['midas_in'] = x
+ return sample
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/autoencoder.py b/repositories/stable-diffusion-stability-ai/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False
+ ):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+ if self.use_ema:
+ self.ema_decay = ema_decay
+ assert 0. < ema_decay < 1.
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix=""):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+ if self.learn_logvar:
+ print(f"{self.__class__.__name__}: Learning logvar")
+ ae_params_list.append(self.loss.logvar)
+ opt_ae = torch.optim.Adam(ae_params_list,
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ if log_ema or self.use_ema:
+ with self.ema_scope():
+ xrec_ema, posterior_ema = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec_ema.shape[1] > 3
+ xrec_ema = self.to_rgb(xrec_ema)
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+ log["reconstructions_ema"] = xrec_ema
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddim.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddim.py
@@ -0,0 +1,336 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ elif isinstance(c, list):
+ c_in = list()
+ assert isinstance(unconditional_conditioning, list)
+ for i in range(len(c)):
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
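+            # classifier-free guidance: extrapolate from the unconditional prediction toward the conditional one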
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bbdd0264b1c27c5d4376d64f2e571f9b0c9437a
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1795 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if reset_ema: assert exists(ckpt_path)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ if reset_ema:
+ assert self.use_ema
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ if self.make_it_fit:
+ n_params = len([name for name, _ in
+ itertools.chain(self.named_parameters(),
+ self.named_buffers())])
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(),
+ self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params
+ ):
+ if not name in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape) == len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys:\n {missing}")
+ if len(unexpected) > 0:
+ print(f"\nUnexpected Keys:\n {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
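+        # v-parameterization: x0 = sqrt(alphas_cumprod_t) * x_t - sqrt(1 - alphas_cumprod_t) * v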
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
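+    # For reference: the scalar returned by p_losses is
+    #     l_simple_weight * mean_b(loss_b) + original_elbo_weight * mean_b(lvlb_weights[t_b] * loss_b),
+    # where loss_b is the per-sample error averaged over channels and pixels; with
+    # original_elbo_weight == 0 (a common setting) only the "simple" DDPM objective contributes.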
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ for k in self.ucg_training:
+ p = self.ucg_training[k]["p"]
+ val = self.ucg_training[k]["val"]
+ if val is None:
+ val = ""
+ for i in range(len(batch[k])):
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[k][i] = val
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ force_null_conditioning=False,
+ *args, **kwargs):
+ self.force_null_conditioning = force_null_conditioning
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ reset_ema = kwargs.pop("reset_ema", False)
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+ if reset_ema:
+ assert self.use_ema
+ print(
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+        with min distance = 0 at the border and max distance = 0.5 at the image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
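+    # Illustrative sketch of how the returned fold/unfold objects fit together (uf == df == 1 case;
+    # names below are local to this comment):
+    #     crops = unfold(x)                                    # (bs, c * kh * kw, Ly * Lx)
+    #     crops = crops.view(bs, c, kh, kw, -1) * weighting    # blend each crop with the border weights
+    #     out = fold(crops.view(bs, c * kh * kw, -1)) / normalization   # re-assemble, normalize overlaps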
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, return_x=False):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+ xc = batch[cond_key]
+ elif cond_key in ['class_label', 'cls']:
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_x:
+ out.extend([x])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
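+    # Illustrative (placeholder names): `cond` is either a dict of conditioning lists or a bare tensor/list, e.g.
+    #     apply_model(x_noisy, t, {'c_crossattn': [text_emb]})                                  # cross-attention only
+    #     apply_model(x_noisy, t, {'c_concat': [concat_latent], 'c_crossattn': [text_emb]})     # hybrid
+    # A bare tensor/list is wrapped into the key matching self.model.conditioning_key.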
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, **kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True, **kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ if self.cond_stage_key in ["class_label", "cls"]:
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+ return self.get_learned_conditioning(xc)
+ else:
+ raise NotImplementedError("todo")
+ if isinstance(c, list): # in case the encoder gives us a list
+ for i in range(len(c)):
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+ else:
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+ return c
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', "cls"]:
+ try:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
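+    # Illustrative (placeholder names): for the standard text-conditioned case
+    # (conditioning_key == 'crossattn') the wrapper is called as
+    #     wrapper(x_noisy, t, c_crossattn=[text_embedding])
+    # and the list is concatenated along dim 1 and passed to the UNet as `context`;
+    # 'concat' instead stacks the extra channels onto x along dim 1 before the UNet call.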
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ if not self.sequential_cross_attn:
+ cc = torch.cat(c_crossattn, 1)
+ else:
+ cc = c_crossattn
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'crossattn-adm':
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+ self.noise_level_key = noise_level_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ if self.noise_level_key is not None:
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+ raise NotImplementedError('TODO')
+
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ uc[k] = c[k]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+ """
+    Basis for different fine-tuning variants, such as inpainting or depth2image
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys: tuple,
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight"
+ ),
+ keep_finetune_dims=4,
+ # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args, **kwargs
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", list())
+ super().__init__(*args, **kwargs)
+ self.finetune_keys = finetune_keys
+ self.concat_keys = concat_keys
+ self.keep_dims = keep_finetune_dims
+ self.c_concat_log_start = c_concat_log_start
+ self.c_concat_log_end = c_concat_log_end
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+ if exists(ckpt_path):
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ # make it explicit, finetune by including extra input channels
+ if exists(self.finetune_keys) and k in self.finetune_keys:
+ new_entry = None
+ for name, param in self.named_parameters():
+ if name in self.finetune_keys:
+ print(
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+ new_entry = torch.zeros_like(param) # zero init
+ assert exists(new_entry), 'did not find matching parameter to modify'
+ new_entry[:, :self.keep_dims, ...] = sd[k]
+ sd[k] = new_entry
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc_cat = c_cat
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+ """
+    Can either run as a pure inpainting model (concat mode only) or with mixed conditioning,
+    e.g. mask as concat and text via cross-attention.
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args, **kwargs
+ ):
+ super().__init__(concat_keys, *args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ bchw = z.shape
+ if ck != self.masked_image_key:
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
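+    # For reference: with the default concat_keys ("mask", "masked_image") the conditioning built above is
+    #     {"c_concat": [cat(mask_resized_to_latent, encoded_masked_image, dim=1)], "c_crossattn": [c]}
+    # which is why inpainting UNets take extra input channels on top of the noisy latent
+    # (exact channel counts depend on the model config).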
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+ log["masked_image"] = rearrange(args[0]["masked_image"],
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on monocular depth estimation
+ """
+
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.depth_model = instantiate_from_config(depth_stage_config)
+ self.depth_stage_key = concat_keys[0]
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ cc = self.depth_model(cc)
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ depth = self.depth_model(args[0][self.depth_stage_key])
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+ log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+ return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on low-res image (and optionally on some spatial noise augmentation)
+ """
+ def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+ low_scale_config=None, low_scale_key=None, *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.reshuffle_patch_size = reshuffle_patch_size
+ self.low_scale_model = None
+ if low_scale_config is not None:
+ print("Initializing a low-scale model")
+ assert exists(low_scale_key)
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ # optionally make spatial noise_level here
+ c_cat = list()
+ noise_level = None
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ cc = rearrange(cc, 'b h w c -> b c h w')
+ if exists(self.reshuffle_patch_size):
+ assert isinstance(self.reshuffle_patch_size, int)
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
+ cc, noise_level = self.low_scale_model(cc)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ if exists(noise_level):
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+ else:
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+ return log
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/dpm_solver.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+ ***
+    Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+    We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+    ***
+    The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+ t = self.inverse_lambda(lambda_t)
+ ===============================================================
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+ 1. For discrete-time DPMs:
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+            Note that we always have alphas_cumprod = cumprod(1 - betas), so only one of `betas` and `alphas_cumprod` needs to be set.
+            **Important**: Please pay special attention to the argument `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+ 2. For continuous-time DPMs:
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+ ===============================================================
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+ Example:
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
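+        # 'linear': invert the quadratic relation between t and lambda in closed form (written in a
+        # numerically stable way); 'discrete': interpolate the tabulated (log_alpha, t) pairs;
+        # 'cosine': closed form via arccos.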
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+ We support four types of the diffusion model by setting `model_type`:
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The "v" prediction is detailed in Appendix D of [1], and is used in Imagen-Video [2].
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follow a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+ ===============================================================
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
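+        # Example: for a discrete-time DPM with total_N = 1000, t_continuous = 1.0 maps to
+        # t_input = 999.0 and t_continuous = 1 / 1000 maps to t_input = 0.0.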
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
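+        # Convert the raw model output into a noise (epsilon) prediction:
+        #   x_start: eps = (x - alpha_t * x0) / sigma_t
+        #   v:       eps = alpha_t * v + sigma_t * x
+        #   score:   eps = -sigma_t * score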
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+        The noise prediction model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
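+                # Classifier-free guidance: evaluate the model on a doubled batch (unconditional and
+                # conditional halves) and mix the two noise predictions with the guidance scale.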
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    assert model_type in ["noise", "x_start", "v", "score"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
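+            # Dynamic thresholding (Imagen): clip x0 to the per-sample quantile s of |x0| (at least
+            # max_val), then rescale by s so the result lies in [-1, 1].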
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+        We combine DPM-Solver-1, 2 and 3 to use up all the function evaluations; this combined scheme is named "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
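+        # Example: steps=20, order=3 gives orders = [3] * 6 + [2]
+        # (six third-order steps and one second-order step, 20 function evaluations in total).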
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+                torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
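+        # First-order (DDIM-equivalent) step over the logSNR interval h = lambda_t - lambda_s: the linear
+        # term is integrated exactly, and the model term enters through phi_1 = expm1(-h) (data
+        # prediction) or expm1(h) (noise prediction).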
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
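+        # D1_0 is a backward finite-difference estimate of h * d(model)/d(lambda) at lambda_prev_0,
+        # built from the two previous model evaluations.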
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
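+        # D1 and D2 are finite-difference estimates of the (scaled) first and second derivatives of the
+        # model output w.r.t. lambda, built from the three previous model evaluations.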
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
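+            # Embedded-pair step-size control: accept the step when the scaled error E <= 1, then adapt h
+            # proportionally to E**(-1/order), clipped so we never step past lambda_0.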
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+ =====================================================
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+                    - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+                - If `order` == 3:
+                    - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+ =====================================================
+        Some advice for choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+ - 'time_quadratic': quadratic time for the time steps.
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
+ (such as CIFAR-10). However, we observed that such trick does not matter for
+ high-resolutional images. As it needs an additional NFE, we do not recommend
+ it for high-resolutional images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+ (especially for steps <= 10). So we recommend to set it to be `True`.
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
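+    # Differentiable piecewise-linear interpolation: each query x is located among the keypoints xp via
+    # a joint sort, then interpolated between its two neighbouring keypoints; queries outside the range
+    # of xp are extrapolated from the outermost segment.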
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
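+    # e.g. expand_dims(v, 4) reshapes v from [N] to [N, 1, 1, 1] so it broadcasts over image tensors.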
+ return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/sampler.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+ "eps": "noise",
+ "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
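+        # Wrap the latent diffusion model as a continuous-time noise prediction function with
+        # classifier-free guidance, then sample with second-order multistep DPM-Solver++ (predict_x0=True).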
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+ return x.to(device), None
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/plms.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
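+        # PLMS is a deterministic multistep sampler, so the stochastic DDIM noise term (eta > 0) is not
+        # supported here.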
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
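As a standalone illustration of the pseudo linear multistep branches above, the helper below reproduces the same Adams-Bashforth style coefficients (a sketch; `eps_history` plays the role of `old_eps`, ordered oldest to newest):

def plms_eps_prime(e_t, eps_history):
    # Higher-order extrapolation of the noise prediction as more history is kept,
    # matching the 2nd/3rd/4th order branches in p_sample_plms.
    if len(eps_history) == 1:
        return (3 * e_t - eps_history[-1]) / 2
    if len(eps_history) == 2:
        return (23 * e_t - 16 * eps_history[-1] + 5 * eps_history[-2]) / 12
    if len(eps_history) >= 3:
        return (55 * e_t - 59 * eps_history[-1] + 37 * eps_history[-2] - 9 * eps_history[-3]) / 24
    raise ValueError("with no history the sampler uses the pseudo improved Euler step instead")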
diff --git a/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/sampling_util.py b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+ return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+ # b c h w
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+ return x0 * (value / s)
\ No newline at end of file
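A small usage sketch for the thresholding helpers above, assuming the file is importable as `ldm.models.diffusion.sampling_util` (the path plms.py imports from):

import torch
from ldm.models.diffusion.sampling_util import norm_thresholding

x0 = torch.randn(2, 4, 64, 64)        # batch of predicted x_0 latents
clamped = norm_thresholding(x0, 1.0)  # samples whose per-sample RMS exceeds 1.0 are scaled down
print(clamped.shape)                  # torch.Size([2, 4, 64, 64])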
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/attention.py b/repositories/stable-diffusion-stability-ai/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..d504d939f6a02cf45f028799d7d73b84500cee06
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/attention.py
@@ -0,0 +1,331 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
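For orientation, a minimal shape check for the CrossAttention module defined above (dimensions are illustrative; requires the ldm package on the import path):

import torch
from ldm.modules.attention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 1024, 320)      # flattened 32x32 feature map as tokens
ctx = torch.randn(2, 77, 768)      # e.g. text-encoder hidden states
print(attn(x, context=ctx).shape)  # torch.Size([2, 1024, 320])

self_attn = CrossAttention(query_dim=320)  # context_dim defaults to query_dim -> self-attention
print(self_attn(x).shape)                  # torch.Size([2, 1024, 320])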
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/model.py b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+ print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+    Build sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic
+    Models. This matches the Fairseq/tensor2tensor implementation, but differs
+    slightly from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+ out = self.proj_out(out)
+ return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x_in = x
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ raise NotImplementedError()
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
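As a rough sketch of how the Encoder/Decoder pair above is typically wired (a small VAE-style configuration; all numbers are illustrative, and if xformers is installed the memory-efficient attention path is selected automatically, which may require a GPU):

import torch
from ldm.modules.diffusionmodules.model import Encoder, Decoder

cfg = dict(ch=64, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=2,
           attn_resolutions=[], dropout=0.0, in_channels=3,
           resolution=256, z_channels=4)
enc = Encoder(double_z=True, **cfg)
dec = Decoder(**cfg)

x = torch.randn(1, 3, 256, 256)
moments = enc(x)            # (1, 2*z_channels, 64, 64): mean and log-variance stacked
z = moments[:, :4]          # use the mean half as a stand-in for sampling
print(dec(z).shape)         # torch.Size([1, 3, 256, 256])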
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
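A quick shape check for the two attention orderings above (an illustrative snippet, not part of the file): both accept a fused [N x (3*H*C) x T] tensor and return [N x (H*C) x T], so they are interchangeable as long as the layout of the fused qkv projection matches the expected split order.

import torch as th
from ldm.modules.diffusionmodules.openaimodel import QKVAttention, QKVAttentionLegacy

qkv = th.randn(2, 3 * 4 * 8, 16)          # N=2, H=4 heads, C=8 channels per head, T=16
legacy = QKVAttentionLegacy(n_heads=4)    # splits heads before splitting q/k/v
new = QKVAttention(n_heads=4)             # splits q/k/v before folding heads into the batch
print(legacy(qkv).shape, new(qkv).shape)  # both torch.Size([2, 32, 16])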
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
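+ # In legacy mode the per-head width is tied to num_heads when a spatial
+ # transformer is used, instead of being taken from num_head_channels.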
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/upscaling.py b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+ # for concatenating a downsampled image to the latent representation
+ def __init__(self, noise_schedule_config=None):
+ super(AbstractLowScaleModel, self).__init__()
+ if noise_schedule_config is not None:
+ self.register_schedule(**noise_schedule_config)
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
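+ # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise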
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def forward(self, x):
+ return x, None
+
+ def decode(self, x):
+ return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+ # no noise level conditioning
+ def __init__(self):
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+ self.max_noise_level = 0
+
+ def forward(self, x):
+ # fix to constant noise level
+ return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+ super().__init__(noise_schedule_config=noise_schedule_config)
+ self.max_noise_level = max_noise_level
+
+ def forward(self, x, noise_level=None):
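+ # Sample a per-example noise level if none is given, then apply that amount of
+ # forward diffusion to the low-resolution conditioning input.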
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ z = self.q_sample(x, noise_level)
+ return z, noise_level
+
+
+
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/util.py b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
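+ # Gather the schedule values a[t] for each batch element and reshape them to
+ # (b, 1, 1, ...) so they broadcast against a tensor of shape x_shape.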
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
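+ # Save the surrounding autocast state so the recomputation in backward()
+ # runs under the same mixed-precision settings as the forward pass.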
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
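+ :param repeat_only: if True, skip the sinusoidal embedding and simply repeat
+ the raw timestep values across the embedding dimension.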
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/distributions/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/distributions/distributions.py b/repositories/stable-diffusion-stability-ai/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
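+ # Analytic KL between diagonal Gaussians (against the standard normal when
+ # `other` is None), summed over all non-batch dimensions.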
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/ema.py b/repositories/stable-diffusion-stability-ai/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
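+ # EMA update with warm-up: while num_updates is tracked, the effective decay
+ # is min(decay, (1 + n) / (10 + n)), so the shadow weights follow the model
+ # quickly at the start of training before settling at the configured decay.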
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert key not in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert key not in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/encoders/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/encoders/modules.py b/repositories/stable-diffusion-stability-ai/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/encoders/modules.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import open_clip
+from ldm.util import default, count_params
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
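+ # Unconditional-conditioning dropout (for classifier-free guidance): with
+ # probability ucg_rate, replace the label with the reserved class n_classes - 1.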
+ if self.ucg_rate > 0. and not disable_dropout:
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = [
+ "last",
+ "pooled",
+ "hidden"
+ ]
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ #"pooled",
+ "last",
+ "penultimate"
+ ]
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
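+ # Run the transformer but stop `self.layer_idx` blocks before the end, so
+ # layer="penultimate" returns the output of the second-to-last residual block.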
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+ clip_max_length=77, t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/bsrgan.py b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernel's size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharpening weight. Default: 0.5.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int): Threshold on the absolute residual (scaled to 0-255) used to build the mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ example: a dict {"image": degraded low-quality image}, uint8, spatially downscaled by roughly the scale factor sf
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
+
+# TODO: in case of a pickle error, replace a += x with a = a + x in add_speckle_noise etc.
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    shuffle_prob: probability of shuffling the full degradation order
+    use_sharp: whether to sharpen the image before degradation
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)  # degradation_bsrgan_variant expects a uint8 image
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/bsrgan_light.py b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
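+    # rotate the x-axis by theta to get the principal direction, build an orthogonal basis V
+    # from it, and form the covariance Sigma = V D V^{-1}, so l1 and l2 scale the rotated axes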
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
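+    # per-sample kernels via one grouped convolution: repeat each kernel over the C channels,
+    # fold batch and channels into the group dimension (groups = n * c), convolve, then unfold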
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate the Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0  # use np.finfo: the scipy.finfo alias is gone in newer SciPy
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Residual threshold (on a 0-255 scale) used to build the sharpening mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
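+    # scale both kernel widths down to a quarter, presumably so this "light" variant
+    # applies noticeably milder blur than the full BSRGAN settings above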
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
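+    # three noise modes: per-channel color noise, single-channel noise broadcast over all
+    # channels, or channel-correlated noise drawn from a random 3x3 covariance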
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
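+    # vals acts as a photon-count scale: Poisson(img * vals) / vals has variance img / vals,
+    # so larger vals means weaker shot noise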
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
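+    # round-trip through an in-memory JPEG encode/decode at the sampled quality factor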
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
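+    # take the matching sf-times-larger window from the HQ image so the LQ/HQ pair stays aligned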
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+    example: dict with key "image" holding the degraded low-quality image (uint8, roughly 1/sf the input size, or resized back to the input size when up=True)
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+        elif i == 1:
+            # a second blur pass in the original BSRGAN; intentionally skipped in this light variant
+            pass
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ if up:
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/utils/test.png b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..e720ed04ac7e1e7938d367e692fb6a742c54a24c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/utils/test.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92e516278f0d3e85e84cfb55b43338e12d5896a0ee3833aafdf378025457d753
+size 441072
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/utils_image.py b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
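+    # tile images larger than p_max x p_max into overlapping p_size x p_size patches;
+    # smaller images are returned unchanged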
+ if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))  # np.int was removed in NumPy 1.24; use the builtin
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+    split the large images from original_dataroot into small overlapping images of size (p_size)x(p_size),
+    and save them into taget_dataroot; only images larger than (p_max)x(p_max) will be split.
+    Args:
+        original_dataroot: directory with the source images
+        taget_dataroot: directory where the patches are saved
+        p_size: size of the small images
+        p_overlap: overlap between adjacent patches; the patch size used in training is a good choice
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Renaming it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+        # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
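+    # modes 0-7 enumerate the eight flip/rotation combinations (the dihedral group of the square)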
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
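+    # PSNR = 20 * log10(MAX / sqrt(MSE)) with MAX = 255 for 8-bit images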
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
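+    # standard SSIM stabilizing constants: C1 = (K1 * L)^2, C2 = (K2 * L)^2 with K1 = 0.01, K2 = 0.03, L = 255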
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# MATLAB 'imresize' function; currently only 'bicubic' is supported
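+# Keys' cubic convolution kernel with a = -0.5, the same kernel MATLAB's bicubic imresize uses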
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias; this requires a larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ else:
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+ # x in 0..1, as produced by applying the MiDaS transform from load_midas_transform to a 0..1 float64 numpy array
+ # NOTE: we expect that the correct transform has been called during dataloading.
+ with torch.no_grad():
+ prediction = self.model(x)
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=x.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+ return prediction
+
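A minimal usage sketch for the module above, not part of the patch. Assumptions: the `dpt_hybrid` weights exist at the `midas_models/...` path listed in `ISL_PATHS`, the repository root is on `sys.path` so the module imports as `ldm.modules.midas.api`, and the input image is random stand-in data; the transform comes from `load_midas_transform`, since `MiDaSInference` itself does not keep it.

```python
import numpy as np
import torch

from ldm.modules.midas.api import MiDaSInference, load_midas_transform  # assumed import path

model_type = "dpt_hybrid"                     # needs midas_models/dpt_hybrid-midas-501f0c75.pt on disk
transform = load_midas_transform(model_type)  # Resize -> NormalizeImage -> PrepareForNet

img = np.random.rand(480, 640, 3).astype(np.float64)   # stand-in for an RGB image in 0..1
x = transform({"image": img})["image"]                  # CHW float32, normalized
x = torch.from_numpy(x).unsqueeze(0)                    # add batch dimension: (1, 3, H', W')

midas = MiDaSInference(model_type)
with torch.no_grad():
    depth = midas(x)                                    # (1, 1, H', W') relative inverse depth
```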
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__init__.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
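A small sketch of what `load` above accepts: either a bare `state_dict` or a full training checkpoint whose weights sit under the `"model"` key next to optimizer state. The toy module and in-memory buffer below are assumptions for illustration only.

```python
import io

import torch
import torch.nn as nn

net = nn.Linear(4, 2)

buf = io.BytesIO()
torch.save({"model": net.state_dict(), "optimizer": {}}, buf)   # checkpoint bundling optimizer state
buf.seek(0)

parameters = torch.load(buf, map_location=torch.device("cpu"))
if "optimizer" in parameters:          # same branch BaseModel.load takes
    parameters = parameters["model"]

nn.Linear(4, 2).load_state_dict(parameters)   # loads cleanly for both checkpoint layouts
```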
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 + ResNet-50 hybrid (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # ResNeXt-101 32x8d WSL
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
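A quick check of the `expand` behaviour of `_make_scratch` above: with `expand=True` the four reassemble convolutions widen by powers of two, which is what `MidasNet_small` relies on. A sketch only; the import path assumes the repository root is on `sys.path`, and the channel list matches the `efficientnet_lite3` branch of `_make_encoder`.

```python
import torch

from ldm.modules.midas.midas.blocks import _make_scratch  # assumed import path

scratch = _make_scratch([32, 48, 136, 384], 64, expand=True)   # efficientnet_lite3 channel sizes
outs = [scratch.layer1_rn, scratch.layer2_rn, scratch.layer3_rn, scratch.layer4_rn]
assert [c.out_channels for c in outs] == [64, 128, 256, 512]   # decoder widths double per level

# Each layer maps its encoder feature map to the decoder width at the same resolution.
x = torch.randn(1, 136, 24, 24)
assert scratch.layer3_rn(x).shape == (1, 256, 24, 24)
```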
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+ False, # Set to true if you want to train from scratch, uses ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+ if self.channels_last == True:
+ x = x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
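The output head built in `DPTDepthModel.__init__` halves the channel count, doubles the spatial size and collapses to a single depth channel before the final squeeze. The sketch below rebuilds that same head by hand (default `features=256`, `non_negative=True`) rather than instantiating `DPT`, which would construct the full timm backbone; the tensor sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

from ldm.modules.midas.midas.blocks import Interpolate  # assumed import path

features = 256
head = nn.Sequential(
    nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
    Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
    nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(True),
    nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
    nn.ReLU(True),   # non_negative=True
    nn.Identity(),
)

path_1 = torch.randn(1, features, 64, 64)   # stand-in for the fused decoder features from refinenet1
out = head(path_1)
assert out.shape == (1, 1, 128, 128)        # channels collapse to 1, spatial size doubles
```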
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+ non_negative (bool, optional): Apply a final ReLU so the predicted depth is non-negative. Defaults to True. The encoder backbone is fixed to resnext101_wsl.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 64.
+ backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+ if self.channels_last==True:
+ print("self.channels_last = ", self.channels_last)
+ x = x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
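A hedged sketch of what `fuse_model` above does: it walks `named_modules()` and fuses Conv→BN(→ReLU) runs in place via `torch.quantization.fuse_modules`. The toy model below is an assumption, and Conv+BN fusion only applies in eval mode; the import path assumes the repository root is on `sys.path`.

```python
import torch
import torch.nn as nn

from ldm.modules.midas.midas.midas_net_custom import fuse_model  # assumed import path

m = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()                       # fusion folds BN into the conv, which is only valid in eval mode

fuse_model(m)                  # conv becomes a fused Conv+ReLU; BN and ReLU slots become Identity

assert isinstance(m[1], nn.Identity) and isinstance(m[2], nn.Identity)
x = torch.randn(1, 3, 16, 16)
assert m(x).shape == (1, 8, 16, 16)   # behaviour-preserving: output shape is unchanged
```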
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size (height, width); the unmodified sample is returned if it already meets the minimum size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+ # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
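A worked example of the transform classes above, mirroring how `load_midas_transform` composes them for the `midas_v21` settings. The random input image and the import path are assumptions; with `upper_bound` and `ensure_multiple_of=32`, a 640x480 image is shrunk to 384x288.

```python
import cv2
import numpy as np

from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet  # assumed import path

resize = Resize(
    384, 384,
    resize_target=None,
    keep_aspect_ratio=True,
    ensure_multiple_of=32,
    resize_method="upper_bound",
    image_interpolation_method=cv2.INTER_CUBIC,
)
normalize = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
prepare = PrepareForNet()

assert resize.get_size(640, 480) == (384, 288)     # (new_width, new_height), both multiples of 32

sample = {"image": np.random.rand(480, 640, 3)}    # HWC float image in 0..1, a stand-in
sample = prepare(normalize(resize(sample)))
assert sample["image"].shape == (3, 288, 384)      # CHW float32, ready for the network
assert sample["image"].dtype == np.float32
```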
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
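A small torch-only sketch of the readout handling above: `Slice` drops the class token, `AddReadout` folds it into every patch token, and `ProjectReadout` concatenates and re-projects it. The toy token sizes and import path are assumptions.

```python
import torch

from ldm.modules.midas.midas.vit import Slice, AddReadout, ProjectReadout  # assumed import path

tokens = torch.randn(2, 1 + 576, 768)   # batch of 2: one cls token + 24x24 patch tokens at ViT-B width

assert Slice()(tokens).shape == (2, 576, 768)        # cls token ignored
assert AddReadout()(tokens).shape == (2, 576, 768)   # cls token added to every patch token

project = ProjectReadout(in_features=768)
assert project(tokens).shape == (2, 576, 768)        # [patch, cls] concat -> Linear(1536, 768) + GELU
```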
diff --git a/repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+ path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
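The `write_depth` normalization above maps the predicted depth range linearly onto the integer range of the chosen bit depth. A numpy-only illustration (the depth array is random stand-in data):

```python
import numpy as np

depth = np.random.rand(240, 320).astype(np.float32) * 10.0   # stand-in for a predicted depth map

bits = 2                                  # 16-bit PNG, as in write_depth(path, depth, bits=2)
max_val = (2 ** (8 * bits)) - 1           # 65535

depth_min, depth_max = depth.min(), depth.max()
if depth_max - depth_min > np.finfo("float").eps:
    out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
    out = np.zeros(depth.shape, dtype=depth.dtype)   # flat map: all zeros

out = out.astype("uint16")
assert out.min() == 0 and out.max() == 65535
```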
diff --git a/repositories/stable-diffusion-stability-ai/ldm/util.py b/repositories/stable-diffusion-stability-ai/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c09ca1c72f7ceb3f9d7f9546aae5561baf62b13
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+ return loss
\ No newline at end of file
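A minimal sketch of `instantiate_from_config` / `get_obj_from_str` above, using a plain torch class as the target. The config dict is an assumption for illustration, not one of the repository's configs, and the import assumes the repository root is on `sys.path`.

```python
import torch

from ldm.util import instantiate_from_config  # assumed import path

config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3, "padding": 1},
}

conv = instantiate_from_config(config)   # resolves the dotted path, then calls it with params
assert isinstance(conv, torch.nn.Conv2d)
assert conv(torch.randn(1, 3, 16, 16)).shape == (1, 8, 16, 16)

# The two sentinel strings short-circuit to None instead of raising:
assert instantiate_from_config("__is_unconditional__") is None
```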
diff --git a/repositories/stable-diffusion-stability-ai/modelcard.md b/repositories/stable-diffusion-stability-ai/modelcard.md
new file mode 100644
index 0000000000000000000000000000000000000000..449e16f261ee23dfc4944caf1b56effad6791415
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/modelcard.md
@@ -0,0 +1,140 @@
+# Stable Diffusion v2 Model Card
+This model card focuses on the models associated with the Stable Diffusion v2, available [here](https://github.com/Stability-AI/stablediffusion/).
+
+## Model Details
+- **Developed by:** Robin Rombach, Patrick Esser
+- **Model type:** Diffusion-based text-to-image generation model
+- **Language(s):** English
+- **License:** CreativeML Open RAIL++-M License
+- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([OpenCLIP-ViT/H](https://github.com/mlfoundations/open_clip)).
+- **Resources for more information:** [GitHub Repository](https://github.com/Stability-AI/).
+- **Cite as:**
+
+ @InProceedings{Rombach_2022_CVPR,
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2022},
+ pages = {10684-10695}
+ }
+
+# Uses
+
+## Direct Use
+The model is intended for research purposes only. Possible research areas and tasks include
+
+- Safe deployment of models which have the potential to generate harmful content.
+- Probing and understanding the limitations and biases of generative models.
+- Generation of artworks and use in design and other artistic processes.
+- Applications in educational or creative tools.
+- Research on generative models.
+
+Excluded uses are described below.
+
+ ### Misuse, Malicious Use, and Out-of-Scope Use
+_Note: This section is originally taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), was used for Stable Diffusion v1, but applies in the same way to Stable Diffusion v2_.
+
+The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
+
+#### Out-of-Scope Use
+The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
+
+#### Misuse and Malicious Use
+Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
+
+- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
+- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
+- Impersonating individuals without their consent.
+- Sexual content without consent of the people who might see it.
+- Mis- and disinformation
+- Representations of egregious violence and gore
+- Sharing of copyrighted or licensed material in violation of its terms of use.
+- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
+
+## Limitations and Bias
+
+### Limitations
+
+- The model does not achieve perfect photorealism
+- The model cannot render legible text
+- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
+- Faces and people in general may not be generated properly.
+- The model was trained mainly with English captions and will not work as well in other languages.
+- The autoencoding part of the model is lossy
+- The model was trained on a subset of the large-scale dataset
+ [LAION-5B](https://laion.ai/blog/laion-5b/), which contains adult, violent and sexual content. To partially mitigate this, we have filtered the dataset using LAION's NSFW detector (see Training section).
+
+### Bias
+While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
+Stable Diffusion v2 was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
+which consists of images that are limited to English descriptions.
+Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
+This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
+ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
+Stable Diffusion v2 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent.
+
+
+## Training
+
+**Training Data**
+The model developers used the following dataset for training the model:
+
+- LAION-5B and subsets (details below). The training data is further filtered using LAION's NSFW detector, with a "p_unsafe" score of 0.1 (conservative). For more details, please refer to LAION-5B's [NeurIPS 2022](https://openreview.net/forum?id=M3Y74vmsMcY) paper and reviewer discussions on the topic.
+
+**Training Procedure**
+Stable Diffusion v2 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
+
+- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
+- Text prompts are encoded through the OpenCLIP-ViT/H text-encoder.
+- The output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
+- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. We also use the so-called _v-objective_, see https://arxiv.org/abs/2202.00512.
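+
+For reference, a compact statement of the two objectives in standard noise-schedule notation (our paraphrase of the cited papers, not text from the original card; with f = 8, a `512x512x3` image maps to a `64x64x4` latent):
+
+```latex
+z_t = \alpha_t z_0 + \sigma_t \epsilon, \qquad
+L_{\epsilon} = \mathbb{E}_{z_0, \epsilon, t}\,\big\lVert \epsilon - \hat{\epsilon}_\theta(z_t, t, c) \big\rVert_2^2, \qquad
+L_{v} = \mathbb{E}_{z_0, \epsilon, t}\,\big\lVert (\alpha_t \epsilon - \sigma_t z_0) - \hat{v}_\theta(z_t, t, c) \big\rVert_2^2
+```
+
+The base checkpoints are trained with the noise-prediction loss; `768-v-ema.ckpt` below uses the v-parameterization.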
+
+We currently provide the following checkpoints:
+
+- `512-base-ema.ckpt`: 550k steps at resolution `256x256` on a subset of [LAION-5B](https://laion.ai/blog/laion-5b/) filtered for explicit pornographic material, using the [LAION-NSFW classifier](https://github.com/LAION-AI/CLIP-based-NSFW-Detector) with `punsafe=0.1` and an [aesthetic score](https://github.com/christophschuhmann/improved-aesthetic-predictor) >= `4.5`.
+ 850k steps at resolution `512x512` on the same dataset with resolution `>= 512x512`.
+- `768-v-ema.ckpt`: Resumed from `512-base-ema.ckpt` and trained for 150k steps using a [v-objective](https://arxiv.org/abs/2202.00512) on the same dataset. Resumed for another 140k steps on a `768x768` subset of our dataset.
+- `512-depth-ema.ckpt`: Resumed from `512-base-ema.ckpt` and finetuned for 200k steps. Added an extra input channel to process the (relative) depth prediction produced by [MiDaS](https://github.com/isl-org/MiDaS) (`dpt_hybrid`) which is used as an additional conditioning.
+The additional input channels of the U-Net which process this extra information were zero-initialized.
+- `512-inpainting-ema.ckpt`: Resumed from `512-base-ema.ckpt` and trained for another 200k steps. Follows the mask-generation strategy presented in [LAMA](https://github.com/saic-mdal/lama) which, in combination with the latent VAE representations of the masked image, are used as an additional conditioning.
+The additional input channels of the U-Net which process this extra information were zero-initialized. The same strategy was used to train the [1.5-inpainting checkpoint](https://huggingface.co/runwayml/stable-diffusion-inpainting).
+- `x4-upscaling-ema.ckpt`: Trained for 1.25M steps on a 10M subset of LAION containing images `>2048x2048`. The model was trained on crops of size `512x512` and is a text-guided [latent upscaling diffusion model](https://arxiv.org/abs/2112.10752).
+In addition to the textual input, it receives a `noise_level` as an input parameter, which can be used to add noise to the low-resolution input according to a [predefined diffusion schedule](configs/stable-diffusion/x4-upscaling.yaml).
+
+- **Hardware:** 32 x 8 x A100 GPUs
+- **Optimizer:** AdamW
+- **Gradient Accumulations**: 1
+- **Batch:** 32 x 8 x 2 x 4 = 2048
+- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
+
+## Evaluation Results
+Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
+5.0, 6.0, 7.0, 8.0) and 50 DDIM sampling steps show the relative improvements of the checkpoints:
+
+![pareto](assets/model-variants.jpg)
+
+Evaluated using 50 DDIM steps and 10000 random prompts from the COCO2017 validation set at 512x512 resolution. Not optimized for FID scores.
+
+## Environmental Impact
+
+**Stable Diffusion v1 Estimated Emissions**
+We estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700); the hardware, runtime, cloud provider, and compute region listed below were used to estimate the carbon impact.
+
+- **Hardware Type:** A100 PCIe 40GB
+- **Hours used:** 200000
+- **Cloud Provider:** AWS
+- **Compute Region:** US-east
+- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 15000 kg CO2 eq.
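+
+As a back-of-envelope consistency check (ours, not from the card; the 0.25 kW average draw is an assumption based on the A100 PCIe 40GB's 250 W board power):
+
+```latex
+200{,}000\ \mathrm{h} \times 0.25\ \mathrm{kW} \approx 50{,}000\ \mathrm{kWh}, \qquad
+\frac{15{,}000\ \mathrm{kg\ CO_2 eq}}{50{,}000\ \mathrm{kWh}} \approx 0.3\ \mathrm{kg\ CO_2 eq / kWh}
+```
+
+i.e. the stated figures imply a grid intensity of roughly 0.3 kg CO2 eq per kWh, which is broadly in line with published estimates for the US-east region.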
+
+## Citation
+ @InProceedings{Rombach_2022_CVPR,
+ author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
+ title = {High-Resolution Image Synthesis With Latent Diffusion Models},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2022},
+ pages = {10684-10695}
+ }
+
+*This model card was written by: Robin Rombach, Patrick Esser and David Ha and is based on the [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion/blob/main/Stable_Diffusion_v1_Model_Card.md) and [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
diff --git a/repositories/stable-diffusion-stability-ai/requirements.txt b/repositories/stable-diffusion-stability-ai/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2404caac780b0048972e40630b27d3e9f1e4cc5e
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/requirements.txt
@@ -0,0 +1,16 @@
+albumentations==0.4.3
+opencv-python
+pudb==2019.2
+imageio==2.9.0
+imageio-ffmpeg==0.4.2
+pytorch-lightning==1.4.2
+torchmetrics==0.6
+omegaconf==2.1.1
+test-tube>=0.7.5
+streamlit>=0.73.1
+einops==0.3.0
+transformers==4.19.2
+webdataset==0.2.5
+open-clip-torch==2.7.0
+gradio==3.11
+-e .
diff --git a/repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py b/repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..c791a4d0b2a510b3525658f4d852d14704ea9f1a
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/gradio/depth2img.py
@@ -0,0 +1,184 @@
+import sys
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image
+from omegaconf import OmegaConf
+from einops import repeat, rearrange
+from pytorch_lightning import seed_everything
+from imwatermark import WatermarkEncoder
+
+from scripts.txt2img import put_watermark
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.data.util import AddMiDaS
+
+torch.set_grad_enabled(False)
+
+
+def initialize_model(config, ckpt):
+ config = OmegaConf.load(config)
+ model = instantiate_from_config(config.model)
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
+
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+ return sampler
+
+
+def make_batch_sd(
+ image,
+ txt,
+ device,
+ num_samples=1,
+ model_type="dpt_hybrid"
+):
+ image = np.array(image.convert("RGB"))
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
+ midas_trafo = AddMiDaS(model_type=model_type)
+ batch = {
+ "jpg": image,
+ "txt": num_samples * [txt],
+ }
+ batch = midas_trafo(batch)
+ batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
+ batch["jpg"] = repeat(batch["jpg"].to(device=device),
+ "1 ... -> n ...", n=num_samples)
+ batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(
+ device=device), "1 ... -> n ...", n=num_samples)
+ return batch
+
+
+def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
+ do_full_sample=False):
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = sampler.model
+ seed_everything(seed)
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+ with torch.no_grad(),\
+ torch.autocast("cuda"):
+ batch = make_batch_sd(
+ image, txt=prompt, device=device, num_samples=num_samples)
+ z = model.get_first_stage_encoding(model.encode_first_stage(
+ batch[model.first_stage_key])) # move to latent space
+ c = model.cond_stage_model.encode(batch["txt"])
+ c_cat = list()
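+ # depth conditioning: predict depth with MiDaS, resize it to the latent resolution, and normalize it to [-1, 1]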
+ for ck in model.concat_keys:
+ cc = batch[ck]
+ cc = model.depth_model(cc)
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ display_depth = (cc - depth_min) / (depth_max - depth_min)
+ depth_image = Image.fromarray(
+ (display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8))
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ # cond
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+ if not do_full_sample:
+ # encode (scaled latent)
+ z_enc = sampler.stochastic_encode(
+ z, torch.tensor([t_enc] * num_samples).to(model.device))
+ else:
+ z_enc = torch.randn_like(z)
+ # decode it
+ samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc_full, callback=callback)
+ x_samples_ddim = model.decode_first_stage(samples)
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ return [depth_image] + [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
+
+
+def pad_image(input_image):
+ pad_w, pad_h = np.max(((2, 2), np.ceil(
+ np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
+ im_padded = Image.fromarray(
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+ return im_padded
+
+
+def predict(input_image, prompt, steps, num_samples, scale, seed, eta, strength):
+ init_image = input_image.convert("RGB")
+ image = pad_image(init_image) # pad to an integer multiple of 64
+
+ sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
+ assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ do_full_sample = strength == 1.
+ t_enc = min(int(strength * steps), steps-1)
+ result = paint(
+ sampler=sampler,
+ image=image,
+ prompt=prompt,
+ t_enc=t_enc,
+ seed=seed,
+ scale=scale,
+ num_samples=num_samples,
+ callback=None,
+ do_full_sample=do_full_sample
+ )
+ return result
+
+
+sampler = initialize_model(sys.argv[1], sys.argv[2])
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Stable Diffusion Depth2Img")
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="pil")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(
+ label="Images", minimum=1, maximum=4, value=1, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1,
+ maximum=50, value=50, step=1)
+ scale = gr.Slider(
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1
+ )
+ strength = gr.Slider(
+ label="Strength", minimum=0.0, maximum=1.0, value=0.9, step=0.01
+ )
+ seed = gr.Slider(
+ label="Seed",
+ minimum=0,
+ maximum=2147483647,
+ step=1,
+ randomize=True,
+ )
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
+ grid=[2], height="auto")
+
+ run_button.click(fn=predict, inputs=[
+ input_image, prompt, ddim_steps, num_samples, scale, seed, eta, strength], outputs=[gallery])
+
+
+block.launch()
diff --git a/repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py b/repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d44f3ddc528011d7421966915b93d0e2803ba5
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/gradio/inpainting.py
@@ -0,0 +1,195 @@
+import sys
+import cv2
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image
+from omegaconf import OmegaConf
+from einops import repeat
+from imwatermark import WatermarkEncoder
+from pathlib import Path
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import instantiate_from_config
+
+
+torch.set_grad_enabled(False)
+
+
+def put_watermark(img, wm_encoder=None):
+ if wm_encoder is not None:
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ img = wm_encoder.encode(img, 'dwtDct')
+ img = Image.fromarray(img[:, :, ::-1])
+ return img
+
+
+def initialize_model(config, ckpt):
+ config = OmegaConf.load(config)
+ model = instantiate_from_config(config.model)
+
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
+
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+
+ return sampler
+
+
+def make_batch_sd(
+ image,
+ mask,
+ txt,
+ device,
+ num_samples=1):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ batch = {
+ "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
+ "txt": num_samples * [txt],
+ "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
+ "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
+ }
+ return batch
+
+
+def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512):
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = sampler.model
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+ prng = np.random.RandomState(seed)
+ start_code = prng.randn(num_samples, 4, h // 8, w // 8)
+ start_code = torch.from_numpy(start_code).to(
+ device=device, dtype=torch.float32)
+
+ with torch.no_grad(), \
+ torch.autocast("cuda"):
+ batch = make_batch_sd(image, mask, txt=prompt,
+ device=device, num_samples=num_samples)
+
+ c = model.cond_stage_model.encode(batch["txt"])
+
+ c_cat = list()
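+ # build the concat conditioning: the mask is resized to the latent resolution and the masked image is VAE-encoded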
+ for ck in model.concat_keys:
+ cc = batch[ck].float()
+ if ck != model.masked_image_key:
+ bchw = [num_samples, 4, h // 8, w // 8]
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = model.get_first_stage_encoding(
+ model.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+
+ # cond
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+
+ shape = [model.channels, h // 8, w // 8]
+ samples_cfg, intermediates = sampler.sample(
+ ddim_steps,
+ num_samples,
+ shape,
+ cond,
+ verbose=False,
+ eta=1.0,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc_full,
+ x_T=start_code,
+ )
+ x_samples_ddim = model.decode_first_stage(samples_cfg)
+
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
+ min=0.0, max=1.0)
+
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
+
+def pad_image(input_image):
+ pad_w, pad_h = np.max(((2, 2), np.ceil(
+ np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
+ im_padded = Image.fromarray(
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+ return im_padded
+
+def predict(input_image, prompt, ddim_steps, num_samples, scale, seed):
+ init_image = input_image["image"].convert("RGB")
+ init_mask = input_image["mask"].convert("RGB")
+ image = pad_image(init_image) # pad to an integer multiple of 64
+ mask = pad_image(init_mask) # pad to an integer multiple of 64
+ width, height = image.size
+ print("Inpainting...", width, height)
+
+ result = inpaint(
+ sampler=sampler,
+ image=image,
+ mask=mask,
+ prompt=prompt,
+ seed=seed,
+ scale=scale,
+ ddim_steps=ddim_steps,
+ num_samples=num_samples,
+ h=height, w=width
+ )
+
+ return result
+
+
+sampler = initialize_model(sys.argv[1], sys.argv[2])
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Stable Diffusion Inpainting")
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', tool='sketch', type="pil")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(
+ label="Images", minimum=1, maximum=4, value=4, step=1)
+ ddim_steps = gr.Slider(label="Steps", minimum=1,
+ maximum=50, value=45, step=1)
+ scale = gr.Slider(
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=10, step=0.1
+ )
+ seed = gr.Slider(
+ label="Seed",
+ minimum=0,
+ maximum=2147483647,
+ step=1,
+ randomize=True,
+ )
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
+ grid=[2], height="auto")
+
+ run_button.click(fn=predict, inputs=[
+ input_image, prompt, ddim_steps, num_samples, scale, seed], outputs=[gallery])
+
+
+block.launch()
diff --git a/repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py b/repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d08fbfae4f9639165e669f3c69c76763c5b32a8
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/gradio/superresolution.py
@@ -0,0 +1,197 @@
+import sys
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image
+from omegaconf import OmegaConf
+from einops import repeat, rearrange
+from pytorch_lightning import seed_everything
+from imwatermark import WatermarkEncoder
+
+from scripts.txt2img import put_watermark
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
+from ldm.util import exists, instantiate_from_config
+
+
+torch.set_grad_enabled(False)
+
+
+def initialize_model(config, ckpt):
+ config = OmegaConf.load(config)
+ model = instantiate_from_config(config.model)
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
+
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+ return sampler
+
+
+def make_batch_sd(
+ image,
+ txt,
+ device,
+ num_samples=1,
+):
+ image = np.array(image.convert("RGB"))
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ batch = {
+ "lr": rearrange(image, 'h w c -> 1 c h w'),
+ "txt": num_samples * [txt],
+ }
+ batch["lr"] = repeat(batch["lr"].to(device=device),
+ "1 ... -> n ...", n=num_samples)
+ return batch
+
+
+def make_noise_augmentation(model, batch, noise_level=None):
+ x_low = batch[model.low_scale_key]
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ x_aug, noise_level = model.low_scale_model(x_low, noise_level)
+ return x_aug, noise_level
+
+
+def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
+ device = torch.device(
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = sampler.model
+ seed_everything(seed)
+ prng = np.random.RandomState(seed)
+ start_code = prng.randn(num_samples, model.channels, h, w)
+ start_code = torch.from_numpy(start_code).to(
+ device=device, dtype=torch.float32)
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+ with torch.no_grad(),\
+ torch.autocast("cuda"):
+ batch = make_batch_sd(
+ image, txt=prompt, device=device, num_samples=num_samples)
+ c = model.cond_stage_model.encode(batch["txt"])
+ c_cat = list()
+ if isinstance(model, LatentUpscaleFinetuneDiffusion):
+ for ck in model.concat_keys:
+ cc = batch[ck]
+ if exists(model.reshuffle_patch_size):
+ assert isinstance(model.reshuffle_patch_size, int)
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ # cond
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+ elif isinstance(model, LatentUpscaleDiffusion):
+ x_augment, noise_level = make_noise_augmentation(
+ model, batch, noise_level)
+ cond = {"c_concat": [x_augment],
+ "c_crossattn": [c], "c_adm": noise_level}
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [x_augment], "c_crossattn": [
+ uc_cross], "c_adm": noise_level}
+ else:
+ raise NotImplementedError()
+
+ shape = [model.channels, h, w]
+ samples, intermediates = sampler.sample(
+ steps,
+ num_samples,
+ shape,
+ cond,
+ verbose=False,
+ eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc_full,
+ x_T=start_code,
+ callback=callback
+ )
+ with torch.no_grad():
+ x_samples_ddim = model.decode_first_stage(samples)
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
+
+
+def pad_image(input_image):
+ pad_w, pad_h = np.max(((2, 2), np.ceil(
+ np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
+ im_padded = Image.fromarray(
+ np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+ return im_padded
+
+
+def predict(input_image, prompt, steps, num_samples, scale, seed, eta, noise_level):
+ init_image = input_image.convert("RGB")
+ image = pad_image(init_image) # pad to an integer multiple of 64
+ width, height = image.size
+
+ noise_level = torch.Tensor(
+ num_samples * [noise_level]).to(sampler.model.device).long()
+ sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
+ result = paint(
+ sampler=sampler,
+ image=image,
+ prompt=prompt,
+ seed=seed,
+ scale=scale,
+ h=height, w=width, steps=steps,
+ num_samples=num_samples,
+ callback=None,
+ noise_level=noise_level
+ )
+ return result
+
+
+sampler = initialize_model(sys.argv[1], sys.argv[2])
+
+block = gr.Blocks().queue()
+with block:
+ with gr.Row():
+ gr.Markdown("## Stable Diffusion Upscaling")
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(source='upload', type="pil")
+ gr.Markdown(
+ "Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat")
+ prompt = gr.Textbox(label="Prompt")
+ run_button = gr.Button(label="Run")
+ with gr.Accordion("Advanced options", open=False):
+ num_samples = gr.Slider(
+ label="Number of Samples", minimum=1, maximum=4, value=1, step=1)
+ steps = gr.Slider(label="DDIM Steps", minimum=2,
+ maximum=200, value=75, step=1)
+ scale = gr.Slider(
+ label="Scale", minimum=0.1, maximum=30.0, value=10, step=0.1
+ )
+ seed = gr.Slider(
+ label="Seed",
+ minimum=0,
+ maximum=2147483647,
+ step=1,
+ randomize=True,
+ )
+ eta = gr.Number(label="eta (DDIM)",
+ value=0.0, min=0.0, max=1.0)
+ noise_level = None
+ if isinstance(sampler.model, LatentUpscaleDiffusion):
+ # TODO: make this work for all models
+ noise_level = gr.Number(
+ label="Noise Augmentation", min=0, max=350, value=20, step=1)
+
+ with gr.Column():
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
+ grid=[2], height="auto")
+
+ run_button.click(fn=predict, inputs=[
+ input_image, prompt, steps, num_samples, scale, seed, eta, noise_level], outputs=[gallery])
+
+
+block.launch()
diff --git a/repositories/stable-diffusion-stability-ai/scripts/img2img.py b/repositories/stable-diffusion-stability-ai/scripts/img2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..9085ba9d37ea6402b9ee543e82f7d8c56a1c273a
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/img2img.py
@@ -0,0 +1,279 @@
+"""make variations of input image"""
+
+import argparse, os
+import PIL
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from torch import autocast
+from contextlib import nullcontext
+from pytorch_lightning import seed_everything
+from imwatermark import WatermarkEncoder
+
+
+from scripts.txt2img import put_watermark
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h}) from {path}")
+ w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2. * image - 1.
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ nargs="?",
+ default="a painting of a virus monster playing guitar",
+ help="the prompt to render"
+ )
+
+ parser.add_argument(
+ "--init-img",
+ type=str,
+ nargs="?",
+ help="path to the input image"
+ )
+
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ default="outputs/img2img-samples"
+ )
+
+ parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+ )
+
+ parser.add_argument(
+ "--fixed_code",
+ action='store_true',
+ help="if enabled, uses the same starting code across all samples ",
+ )
+
+ parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+ )
+ parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=1,
+ help="sample this often",
+ )
+
+ parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+ )
+ parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor, most often 8 or 16",
+ )
+
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=2,
+ help="how many samples to produce for each given prompt. A.k.a batch size",
+ )
+
+ parser.add_argument(
+ "--n_rows",
+ type=int,
+ default=0,
+ help="rows in the grid (default: n_samples)",
+ )
+
+ parser.add_argument(
+ "--scale",
+ type=float,
+ default=9.0,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+ )
+
+ parser.add_argument(
+ "--strength",
+ type=float,
+ default=0.8,
+ help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
+ )
+
+ parser.add_argument(
+ "--from-file",
+ type=str,
+ help="if specified, load prompts from this file",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stable-diffusion/v2-inference.yaml",
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ help="path to checkpoint of model",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="the seed (for reproducible sampling)",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ help="evaluate at this precision",
+ choices=["full", "autocast"],
+ default="autocast"
+ )
+
+ opt = parser.parse_args()
+ seed_everything(opt.seed)
+
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ sampler = DDIMSampler(model)
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+ batch_size = opt.n_samples
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ if not opt.from_file:
+ prompt = opt.prompt
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+
+ else:
+ print(f"reading prompts from {opt.from_file}")
+ with open(opt.from_file, "r") as f:
+ data = f.read().splitlines()
+ data = list(chunk(data, batch_size))
+
+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+ base_count = len(os.listdir(sample_path))
+ grid_count = len(os.listdir(outpath)) - 1
+
+ assert os.path.isfile(opt.init_img)
+ init_image = load_img(opt.init_img).to(device)
+ init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
+
+ sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
+
+ assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ t_enc = int(opt.strength * opt.ddim_steps)
+ print(f"target t_enc is {t_enc} steps")
+
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ all_samples = list()
+ for n in trange(opt.n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+ uc = None
+ if opt.scale != 1.0:
+ uc = model.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+
+ # encode (scaled latent)
+ z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
+ # decode it
+ samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc, )
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for x_sample in x_samples:
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ img = Image.fromarray(x_sample.astype(np.uint8))
+ img = put_watermark(img, wm_encoder)
+ img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+ base_count += 1
+ all_samples.append(x_samples)
+
+ # additionally, save as grid
+ grid = torch.stack(all_samples, 0)
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+ grid = make_grid(grid, nrow=n_rows)
+
+ # to image
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+ grid = Image.fromarray(grid.astype(np.uint8))
+ grid = put_watermark(grid, wm_encoder)
+ grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+ grid_count += 1
+
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py b/repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f80223405a26ded02964513293b8b3316c34344
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/streamlit/depth2img.py
@@ -0,0 +1,157 @@
+import sys
+import torch
+import numpy as np
+import streamlit as st
+from PIL import Image
+from omegaconf import OmegaConf
+from einops import repeat, rearrange
+from pytorch_lightning import seed_everything
+from imwatermark import WatermarkEncoder
+
+from scripts.txt2img import put_watermark
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.data.util import AddMiDaS
+
+torch.set_grad_enabled(False)
+
+
+@st.cache(allow_output_mutation=True)
+def initialize_model(config, ckpt):
+ config = OmegaConf.load(config)
+ model = instantiate_from_config(config.model)
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+ return sampler
+
+
+def make_batch_sd(
+ image,
+ txt,
+ device,
+ num_samples=1,
+ model_type="dpt_hybrid"
+):
+ image = np.array(image.convert("RGB"))
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
+ midas_trafo = AddMiDaS(model_type=model_type)
+ batch = {
+ "jpg": image,
+ "txt": num_samples * [txt],
+ }
+ batch = midas_trafo(batch)
+ batch["jpg"] = rearrange(batch["jpg"], 'h w c -> 1 c h w')
+ batch["jpg"] = repeat(batch["jpg"].to(device=device), "1 ... -> n ...", n=num_samples)
+ batch["midas_in"] = repeat(torch.from_numpy(batch["midas_in"][None, ...]).to(device=device), "1 ... -> n ...", n=num_samples)
+ return batch
+
+
+def paint(sampler, image, prompt, t_enc, seed, scale, num_samples=1, callback=None,
+ do_full_sample=False):
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = sampler.model
+ seed_everything(seed)
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+ with torch.no_grad(),\
+ torch.autocast("cuda"):
+ batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
+ z = model.get_first_stage_encoding(model.encode_first_stage(batch[model.first_stage_key])) # move to latent space
+ c = model.cond_stage_model.encode(batch["txt"])
+ c_cat = list()
+ for ck in model.concat_keys:
+ cc = batch[ck]
+ cc = model.depth_model(cc)
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ display_depth = (cc - depth_min) / (depth_max - depth_min)
+ st.image(Image.fromarray((display_depth[0, 0, ...].cpu().numpy() * 255.).astype(np.uint8)))
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min) - 1.
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ # cond
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+ if not do_full_sample:
+ # encode (scaled latent)
+ z_enc = sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
+ else:
+ z_enc = torch.randn_like(z)
+ # decode it
+ samples = sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc_full, callback=callback)
+ x_samples_ddim = model.decode_first_stage(samples)
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
+
+
+def run():
+ st.title("Stable Diffusion Depth2Img")
+ # run via streamlit run scripts/streamlit/depth2img.py
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
+
+ image = st.file_uploader("Image", ["jpg", "png"])
+ if image:
+ image = Image.open(image)
+ w, h = image.size
+ st.text(f"loaded input image of size ({w}, {h})")
+ width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+ image = image.resize((width, height))
+ st.text(f"resized input image to size ({width}, {height} (w, h))")
+ st.image(image)
+
+ prompt = st.text_input("Prompt")
+
+ seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
+ num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
+ scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
+ steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
+ strength = st.slider("Strength", min_value=0., max_value=1., value=0.9)
+
+ t_progress = st.progress(0)
+ def t_callback(t):
+ t_progress.progress(min((t + 1) / t_enc, 1.))
+
+ assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ do_full_sample = strength == 1.
+ t_enc = min(int(strength * steps), steps-1)
+ sampler.make_schedule(steps, ddim_eta=0., verbose=True)
+ if st.button("Sample"):
+ result = paint(
+ sampler=sampler,
+ image=image,
+ prompt=prompt,
+ t_enc=t_enc,
+ seed=seed,
+ scale=scale,
+ num_samples=num_samples,
+ callback=t_callback,
+ do_full_sample=do_full_sample,
+ )
+ st.write("Result")
+ for image in result:
+ st.image(image, output_format='PNG')
+
+
+if __name__ == "__main__":
+ run()
diff --git a/repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py b/repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..c35772f063da8bdf0091e3931dbe12b1869dd11a
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/streamlit/inpainting.py
@@ -0,0 +1,195 @@
+import sys
+import cv2
+import torch
+import numpy as np
+import streamlit as st
+from PIL import Image
+from omegaconf import OmegaConf
+from einops import repeat
+from streamlit_drawable_canvas import st_canvas
+from imwatermark import WatermarkEncoder
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import instantiate_from_config
+
+
+torch.set_grad_enabled(False)
+
+
+def put_watermark(img, wm_encoder=None):
+ if wm_encoder is not None:
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ img = wm_encoder.encode(img, 'dwtDct')
+ img = Image.fromarray(img[:, :, ::-1])
+ return img
+
+
+@st.cache(allow_output_mutation=True)
+def initialize_model(config, ckpt):
+ config = OmegaConf.load(config)
+ model = instantiate_from_config(config.model)
+
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+
+ return sampler
+
+
+def make_batch_sd(
+ image,
+ mask,
+ txt,
+ device,
+ num_samples=1):
+ image = np.array(image.convert("RGB"))
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = np.array(mask.convert("L"))
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ batch = {
+ "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
+ "txt": num_samples * [txt],
+ "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
+ "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
+ }
+ return batch
+
+
+def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512, eta=1.):
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = sampler.model
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+ prng = np.random.RandomState(seed)
+ start_code = prng.randn(num_samples, 4, h // 8, w // 8)
+ start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
+
+ with torch.no_grad(), \
+ torch.autocast("cuda"):
+ batch = make_batch_sd(image, mask, txt=prompt, device=device, num_samples=num_samples)
+
+ c = model.cond_stage_model.encode(batch["txt"])
+
+ c_cat = list()
+ for ck in model.concat_keys:
+ cc = batch[ck].float()
+ if ck != model.masked_image_key:
+ bchw = [num_samples, 4, h // 8, w // 8]
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+
+ # cond
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+
+ shape = [model.channels, h // 8, w // 8]
+ samples_cfg, intermediates = sampler.sample(
+ ddim_steps,
+ num_samples,
+ shape,
+ cond,
+ verbose=False,
+ eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc_full,
+ x_T=start_code,
+ )
+ x_samples_ddim = model.decode_first_stage(samples_cfg)
+
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0,
+ min=0.0, max=1.0)
+
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
+
+
+def run():
+ st.title("Stable Diffusion Inpainting")
+
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
+
+ image = st.file_uploader("Image", ["jpg", "png"])
+ if image:
+ image = Image.open(image)
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+ width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+ image = image.resize((width, height))
+
+ prompt = st.text_input("Prompt")
+
+ seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
+ num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
+ scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=10., step=0.1)
+ ddim_steps = st.slider("DDIM Steps", min_value=0, max_value=50, value=50, step=1)
+ eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
+
+ fill_color = "rgba(255, 255, 255, 0.0)"
+ stroke_width = st.number_input("Brush Size",
+ value=64,
+ min_value=1,
+ max_value=100)
+ stroke_color = "rgba(255, 255, 255, 1.0)"
+ bg_color = "rgba(0, 0, 0, 1.0)"
+ drawing_mode = "freedraw"
+
+ st.write("Canvas")
+ st.caption(
+ "Draw a mask to inpaint, then click the 'Send to Streamlit' button (bottom left, with an arrow on it).")
+ canvas_result = st_canvas(
+ fill_color=fill_color,
+ stroke_width=stroke_width,
+ stroke_color=stroke_color,
+ background_color=bg_color,
+ background_image=image,
+ update_streamlit=False,
+ height=height,
+ width=width,
+ drawing_mode=drawing_mode,
+ key="canvas",
+ )
+ if canvas_result:
+ mask = canvas_result.image_data
+ mask = mask[:, :, -1] > 0
+ if mask.sum() > 0:
+ mask = Image.fromarray(mask)
+
+ result = inpaint(
+ sampler=sampler,
+ image=image,
+ mask=mask,
+ prompt=prompt,
+ seed=seed,
+ scale=scale,
+ ddim_steps=ddim_steps,
+ num_samples=num_samples,
+ h=height, w=width, eta=eta
+ )
+ st.write("Inpainted")
+ for image in result:
+ st.image(image, output_format='PNG')
+
+
+if __name__ == "__main__":
+ run()
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py b/repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1172b02ea8141781118d4536ba28de0f24404a1
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/streamlit/superresolution.py
@@ -0,0 +1,170 @@
+import sys
+import torch
+import numpy as np
+import streamlit as st
+from PIL import Image
+from omegaconf import OmegaConf
+from einops import repeat, rearrange
+from pytorch_lightning import seed_everything
+from imwatermark import WatermarkEncoder
+
+from scripts.txt2img import put_watermark
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentUpscaleFinetuneDiffusion
+from ldm.util import exists, instantiate_from_config
+
+
+torch.set_grad_enabled(False)
+
+
+@st.cache(allow_output_mutation=True)
+def initialize_model(config, ckpt):
+ config = OmegaConf.load(config)
+ model = instantiate_from_config(config.model)
+ model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False)
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ sampler = DDIMSampler(model)
+ return sampler
+
+
+def make_batch_sd(
+ image,
+ txt,
+ device,
+ num_samples=1,
+):
+ image = np.array(image.convert("RGB"))
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+ batch = {
+ "lr": rearrange(image, 'h w c -> 1 c h w'),
+ "txt": num_samples * [txt],
+ }
+ batch["lr"] = repeat(batch["lr"].to(device=device), "1 ... -> n ...", n=num_samples)
+ return batch
+
+
+def make_noise_augmentation(model, batch, noise_level=None):
+ x_low = batch[model.low_scale_key]
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ x_aug, noise_level = model.low_scale_model(x_low, noise_level)
+ return x_aug, noise_level
+
+
+def paint(sampler, image, prompt, seed, scale, h, w, steps, num_samples=1, callback=None, eta=0., noise_level=None):
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = sampler.model
+ seed_everything(seed)
+ prng = np.random.RandomState(seed)
+ start_code = prng.randn(num_samples, model.channels, h , w)
+ start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+ with torch.no_grad(),\
+ torch.autocast("cuda"):
+ batch = make_batch_sd(image, txt=prompt, device=device, num_samples=num_samples)
+ c = model.cond_stage_model.encode(batch["txt"])
+ c_cat = list()
+ if isinstance(model, LatentUpscaleFinetuneDiffusion):
+ for ck in model.concat_keys:
+ cc = batch[ck]
+ if exists(model.reshuffle_patch_size):
+ assert isinstance(model.reshuffle_patch_size, int)
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=model.reshuffle_patch_size, p2=model.reshuffle_patch_size)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ # cond
+ cond = {"c_concat": [c_cat], "c_crossattn": [c]}
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
+ elif isinstance(model, LatentUpscaleDiffusion):
+ x_augment, noise_level = make_noise_augmentation(model, batch, noise_level)
+ cond = {"c_concat": [x_augment], "c_crossattn": [c], "c_adm": noise_level}
+ # uncond cond
+ uc_cross = model.get_unconditional_conditioning(num_samples, "")
+ uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
+ else:
+ raise NotImplementedError()
+
+ shape = [model.channels, h, w]
+ samples, intermediates = sampler.sample(
+ steps,
+ num_samples,
+ shape,
+ cond,
+ verbose=False,
+ eta=eta,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc_full,
+ x_T=start_code,
+ callback=callback
+ )
+ with torch.no_grad():
+ x_samples_ddim = model.decode_first_stage(samples)
+ result = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255
+ st.text(f"upscaled image shape: {result.shape}")
+ return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result]
+
+
+def run():
+ st.title("Stable Diffusion Upscaling")
+ # run via streamlit run scripts/streamlit/superresolution.py
+ sampler = initialize_model(sys.argv[1], sys.argv[2])
+
+ image = st.file_uploader("Image", ["jpg", "png"])
+ if image:
+ image = Image.open(image)
+ w, h = image.size
+ st.text(f"loaded input image of size ({w}, {h})")
+ width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+ image = image.resize((width, height))
+ st.text(f"resized input image to size ({width}, {height} (w, h))")
+ st.image(image)
+
+ st.write(f"\n Tip: Add a description of the object that should be upscaled, e.g.: 'a professional photograph of a cat'")
+ prompt = st.text_input("Prompt", "a high quality professional photograph")
+
+ seed = st.number_input("Seed", min_value=0, max_value=1000000, value=0)
+ num_samples = st.number_input("Number of Samples", min_value=1, max_value=64, value=1)
+ scale = st.slider("Scale", min_value=0.1, max_value=30.0, value=9.0, step=0.1)
+ steps = st.slider("DDIM Steps", min_value=2, max_value=250, value=50, step=1)
+ eta = st.sidebar.number_input("eta (DDIM)", value=0., min_value=0., max_value=1.)
+
+ noise_level = None
+ if isinstance(sampler.model, LatentUpscaleDiffusion):
+ # TODO: make this work for all models
+ noise_level = st.sidebar.number_input("Noise Augmentation", min_value=0, max_value=350, value=20)
+ noise_level = torch.Tensor(num_samples * [noise_level]).to(sampler.model.device).long()
+
+ t_progress = st.progress(0)
+ def t_callback(t):
+ t_progress.progress(min((t + 1) / steps, 1.))
+
+ sampler.make_schedule(steps, ddim_eta=eta, verbose=True)
+ if st.button("Sample"):
+ result = paint(
+ sampler=sampler,
+ image=image,
+ prompt=prompt,
+ seed=seed,
+ scale=scale,
+ h=height, w=width, steps=steps,
+ num_samples=num_samples,
+ callback=t_callback,
+ noise_level=noise_level,
+ eta=eta
+ )
+ st.write("Result")
+ for image in result:
+ st.image(image, output_format='PNG')
+
+
+if __name__ == "__main__":
+ run()
diff --git a/repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py b/repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py
new file mode 100644
index 0000000000000000000000000000000000000000..f93f8a6e70763c0e284157bc8225827520b2f5ef
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/tests/test_watermark.py
@@ -0,0 +1,18 @@
+import cv2
+import fire
+from imwatermark import WatermarkDecoder
+
+
+def testit(img_path):
+ bgr = cv2.imread(img_path)
+ decoder = WatermarkDecoder('bytes', 136)
+ watermark = decoder.decode(bgr, 'dwtDct')
+ try:
+ dec = watermark.decode('utf-8')
+ except:
+ dec = "null"
+ print(dec)
+
+
+if __name__ == "__main__":
+ fire.Fire(testit)
\ No newline at end of file
diff --git a/repositories/stable-diffusion-stability-ai/scripts/txt2img.py b/repositories/stable-diffusion-stability-ai/scripts/txt2img.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ed42a3cd87347998e947362e8845f28bf580fdd
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/scripts/txt2img.py
@@ -0,0 +1,289 @@
+import argparse, os
+import cv2
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange
+from torchvision.utils import make_grid
+from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import nullcontext
+from imwatermark import WatermarkEncoder
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+
+torch.set_grad_enabled(False)
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ nargs="?",
+ default="a professional photograph of an astronaut riding a triceratops",
+ help="the prompt to render"
+ )
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ default="outputs/txt2img-samples"
+ )
+ parser.add_argument(
+ "--steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+ )
+ parser.add_argument(
+ "--plms",
+ action='store_true',
+ help="use plms sampling",
+ )
+ parser.add_argument(
+ "--dpm",
+ action='store_true',
+ help="use DPM (2) sampler",
+ )
+ parser.add_argument(
+ "--fixed_code",
+ action='store_true',
+ help="if enabled, uses the same starting code across all samples ",
+ )
+ parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+ )
+ parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=3,
+ help="sample this often",
+ )
+ parser.add_argument(
+ "--H",
+ type=int,
+ default=512,
+ help="image height, in pixel space",
+ )
+ parser.add_argument(
+ "--W",
+ type=int,
+ default=512,
+ help="image width, in pixel space",
+ )
+ parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+ )
+ parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor, most often 8 or 16",
+ )
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=3,
+ help="how many samples to produce for each given prompt. A.k.a batch size",
+ )
+ parser.add_argument(
+ "--n_rows",
+ type=int,
+ default=0,
+ help="rows in the grid (default: n_samples)",
+ )
+ parser.add_argument(
+ "--scale",
+ type=float,
+ default=9.0,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+ )
+ parser.add_argument(
+ "--from-file",
+ type=str,
+ help="if specified, load prompts from this file, separated by newlines",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stable-diffusion/v2-inference.yaml",
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ help="path to checkpoint of model",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="the seed (for reproducible sampling)",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ help="evaluate at this precision",
+ choices=["full", "autocast"],
+ default="autocast"
+ )
+ parser.add_argument(
+ "--repeat",
+ type=int,
+ default=1,
+ help="repeat each prompt in file this often",
+ )
+ opt = parser.parse_args()
+ return opt
+
+
+def put_watermark(img, wm_encoder=None):
+ if wm_encoder is not None:
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+ img = wm_encoder.encode(img, 'dwtDct')
+ img = Image.fromarray(img[:, :, ::-1])
+ return img
+
+
+def main(opt):
+ seed_everything(opt.seed)
+
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ if opt.plms:
+ sampler = PLMSSampler(model)
+ elif opt.dpm:
+ sampler = DPMSolverSampler(model)
+ else:
+ sampler = DDIMSampler(model)
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir
+
+ print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ wm = "SDV2"
+ wm_encoder = WatermarkEncoder()
+ wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+
+ batch_size = opt.n_samples
+ n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
+ if not opt.from_file:
+ prompt = opt.prompt
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+
+ else:
+ print(f"reading prompts from {opt.from_file}")
+ with open(opt.from_file, "r") as f:
+ data = f.read().splitlines()
+ data = [p for p in data for i in range(opt.repeat)]
+ data = list(chunk(data, batch_size))
+
+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+ sample_count = 0
+ base_count = len(os.listdir(sample_path))
+ grid_count = len(os.listdir(outpath)) - 1
+
+ start_code = None
+ if opt.fixed_code:
+ start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
+ with torch.no_grad(), \
+ precision_scope("cuda"), \
+ model.ema_scope():
+ all_samples = list()
+ for n in trange(opt.n_iter, desc="Sampling"):
+ for prompts in tqdm(data, desc="data"):
+ uc = None
+ if opt.scale != 1.0:
+ uc = model.get_learned_conditioning(batch_size * [""])
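+                # uc is the unconditional ("empty prompt") conditioning; inside the
+                # sampler the two noise predictions are mixed as
+                #   eps = eps_uncond + scale * (eps_cond - eps_uncond)
+                # (cf. the --scale help text), so scale == 1.0 would not need uc at all.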
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+ samples, _ = sampler.sample(S=opt.steps,
+ conditioning=c,
+ batch_size=opt.n_samples,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc,
+ eta=opt.ddim_eta,
+ x_T=start_code)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for x_sample in x_samples:
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ img = Image.fromarray(x_sample.astype(np.uint8))
+ img = put_watermark(img, wm_encoder)
+ img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+ base_count += 1
+ sample_count += 1
+
+ all_samples.append(x_samples)
+
+ # additionally, save as grid
+ grid = torch.stack(all_samples, 0)
+ grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+ grid = make_grid(grid, nrow=n_rows)
+
+ # to image
+ grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+ grid = Image.fromarray(grid.astype(np.uint8))
+ grid = put_watermark(grid, wm_encoder)
+ grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+ grid_count += 1
+
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+ f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+ opt = parse_args()
+ main(opt)
diff --git a/repositories/stable-diffusion-stability-ai/setup.py b/repositories/stable-diffusion-stability-ai/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f5b4d874f0f19ece54fac2dd50b39774b86c5b
--- /dev/null
+++ b/repositories/stable-diffusion-stability-ai/setup.py
@@ -0,0 +1,13 @@
+from setuptools import setup, find_packages
+
+setup(
+ name='stable-diffusion',
+ version='0.0.1',
+ description='',
+ packages=find_packages(),
+ install_requires=[
+ 'torch',
+ 'numpy',
+ 'tqdm',
+ ],
+)
\ No newline at end of file
diff --git a/repositories/taming-transformers/License.txt b/repositories/taming-transformers/License.txt
new file mode 100644
index 0000000000000000000000000000000000000000..57fb4153bafcd64b60377ba0ba2c79b7530efc1e
--- /dev/null
+++ b/repositories/taming-transformers/License.txt
@@ -0,0 +1,19 @@
+Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/repositories/taming-transformers/README.md b/repositories/taming-transformers/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d295fbf75703e6cd285330432785b8cdea072ba7
--- /dev/null
+++ b/repositories/taming-transformers/README.md
@@ -0,0 +1,410 @@
+# Taming Transformers for High-Resolution Image Synthesis
+##### CVPR 2021 (Oral)
+![teaser](assets/mountain.jpeg)
+
+[**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)
+[Patrick Esser](https://github.com/pesser)\*,
+[Robin Rombach](https://github.com/rromb)\*,
+[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+\* equal contribution
+
+**tl;dr** We combine the efficiency of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
+
+![teaser](assets/teaser.png)
+[arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
+
+
+### News
+#### 2022
+- More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
+- Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
+#### 2021
+- Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
+- Included a bugfix for the quantizer. For backward compatibility it is
+ disabled by default (which corresponds to always training with `beta=1.0`).
+ Use `legacy=False` in the quantizer config to enable it.
+ Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
+- Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
+- Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
+- Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
+- Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
+- Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization.
+ See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
+- We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
+- We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
+- The streamlit demo now supports image completions.
+- We now include a couple of examples from the D-RIN dataset so you can run the
+ [D-RIN demo](#d-rin) without preparing the dataset first.
+- You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
+
+## Requirements
+A suitable [conda](https://conda.io/) environment named `taming` can be created
+and activated with:
+
+```
+conda env create -f environment.yaml
+conda activate taming
+```
+## Overview of pretrained models
+The following table provides an overview of all models that are currently available.
+FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
+For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
+See the corresponding [colab
+notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
+for a comparison and discussion of reconstruction capabilities.
+
+| Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments
+| ------------- | ------------- | ------------- |------------- | ------------- |------------- |
+| FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) |
+| CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
+| ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
+| COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
+| ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |
+| | | | || |
+| FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
+| S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
+| D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
+| | | | | || |
+| VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
+| VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
+| VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
+| VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
+| VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs.
+| | | | | || |
+| DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
+
+
+## Running pretrained models
+
+The commands below will start a streamlit demo which supports sampling at
+different resolutions and image completions. To run a non-interactive version
+of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
+by `python scripts/make_samples.py --outdir <output_directory>` and
+keep the remaining command line arguments.
+
+To sample from unconditional or class-conditional models,
+run `python scripts/sample_fast.py -r <path/to/model_logdir>`.
+We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models,
+respectively.
+
+### S-FLCKR
+![teaser](assets/sunset_and_ocean.jpg)
+
+You can also [run this model in a Colab
+notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
+which includes all necessary steps to start sampling.
+
+Download the
+[2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
+folder and place it into `logs`. Then, run
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
+```
+
+### ImageNet
+![teaser](assets/imagenet.png)
+
+Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
+folder and place it into logs. Sampling from the class-conditional ImageNet
+model does not require any data preparation. To produce 50 samples for each of
+the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
+sampling and temperature t=1.0, run
+
+```
+python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
+```
+
+To restrict the model to certain classes, provide them via the `--classes` argument, separated by
+commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
+
+```
+python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
+```
+We recommend experimenting with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.
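+
+For orientation, this is roughly what the top-k / nucleus filtering does to the logits
+before each token is drawn; a minimal, self-contained sketch (not the implementation in
+`scripts/sample_fast.py`):
+
+```
+import torch
+
+def sample_token(logits, top_k=600, top_p=0.92, temperature=1.0):
+    """Sample one token id from a 1-D logits vector with top-k / nucleus filtering."""
+    logits = logits / temperature
+    if top_k is not None and top_k < logits.numel():
+        # keep only the k largest logits
+        kth_value = torch.topk(logits, top_k).values[-1]
+        logits = logits.masked_fill(logits < kth_value, float("-inf"))
+    if top_p is not None:
+        # keep the smallest set of top tokens whose cumulative probability exceeds top_p
+        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
+        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+        to_remove = cum_probs > top_p
+        to_remove[1:] = to_remove[:-1].clone()  # shift right so the boundary token survives
+        to_remove[0] = False                    # always keep the most likely token
+        logits[sorted_idx[to_remove]] = float("-inf")
+    probs = torch.softmax(logits, dim=-1)
+    return torch.multinomial(probs, num_samples=1).item()
+```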
+
+### FFHQ/CelebA-HQ
+
+Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and
+[2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf)
+folders and place them into logs.
+Again, sampling from these unconditional models does not require any data preparation.
+To produce 50000 samples, with k=250 for top-k sampling,
+p=1.0 for nucleus sampling and temperature t=1.0, run
+
+```
+python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
+```
+for FFHQ and
+
+```
+python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
+```
+to sample from the CelebA-HQ model.
+For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
+
+### FacesHQ
+![teaser](assets/faceshq.jpg)
+
+Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
+place it into `logs`. Follow the data preparation steps for
+[CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
+```
+
+### D-RIN
+![teaser](assets/drin.jpg)
+
+Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
+place it into `logs`. To run the demo on a couple of example depth maps
+included in the repository, run
+
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
+```
+
+To run the demo on the complete validation set, first follow the data preparation steps for
+[ImageNet](#imagenet) and then run
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
+```
+
+### COCO
+Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
+place it into `logs`. To run the demo on a couple of example segmentation maps
+included in the repository, run
+
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
+```
+
+### ADE20k
+Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
+place it into `logs`. To run the demo on a couple of example segmentation maps
+included in the repository, run
+
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
+```
+
+## Scene Image Synthesis
+![teaser](assets/scene_images_samples.svg)
+Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). Supported datasets: COCO and Open Images.
+
+### Training
+Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
+Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
+Download the full COCO/OI datasets and adapt `data_path` in the same files, unless the 100 files provided for training and validation already suit your needs.
+
+Code can be run with
+`python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
+or
+`python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
+
+### Sampling
+Train a model as described above or download a pre-trained model:
+ - [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing), trained for 100 epochs. On 256x256 pixels: FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-suited to generating high-resolution images, e.g. 512x512.
+ - [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO), which allows sampling on smaller GPUs (4 GB is enough for 256x256 px images). The model was trained for 60 epochs with 10% soft loss and 90% hard loss. On 256x256 pixels: FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
+ - [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
+ - [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
+
+When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training).
+
+Scene image generation can be run with
+`python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
+
+
+## Training on custom data
+
+Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
+These are the steps to follow to make this work:
+1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
+2. put your .jpg files in a folder `your_folder`
+3. create 2 text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`); a Python alternative is sketched below the list
+4. adapt `configs/custom_vqgan.yaml` to point to these 2 files
+5. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
+   train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
+
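+For creating the `xx_train.txt` / `xx_test.txt` lists, a small Python alternative to the
+`find` command that also performs the train/test split (the file names and the 10% split
+are just an example):
+
+```
+import glob, os, random
+
+def write_file_lists(image_dir, train_txt="xx_train.txt", test_txt="xx_test.txt",
+                     test_fraction=0.1, seed=0):
+    """Write absolute .jpg paths into train/test list files, one path per line."""
+    files = sorted(glob.glob(os.path.join(os.path.abspath(image_dir), "*.jpg")))
+    random.Random(seed).shuffle(files)
+    n_test = max(1, int(len(files) * test_fraction))
+    with open(test_txt, "w") as f:
+        f.write("\n".join(files[:n_test]) + "\n")
+    with open(train_txt, "w") as f:
+        f.write("\n".join(files[n_test:]) + "\n")
+
+write_file_lists("your_folder")
+```
+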
+## Data Preparation
+
+### ImageNet
+The code will try to download (through [Academic
+Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
+is used. However, since ImageNet is quite large, this requires a lot of disk
+space and time. If you already have ImageNet on your disk, you can speed things
+up by putting the data into
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
+`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
+of `train`/`validation`. It should have the following structure:
+
+```
+${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
+├── n01440764
+│ ├── n01440764_10026.JPEG
+│ ├── n01440764_10027.JPEG
+│ ├── ...
+├── n01443537
+│ ├── n01443537_10007.JPEG
+│ ├── n01443537_10014.JPEG
+│ ├── ...
+├── ...
+```
+
+If you haven't extracted the data, you can also place
+`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
+extracted into the above structure without downloading the data again. Note that this
+will only happen if neither a folder
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
+if you want to force running the dataset preparation again.
+
+You will then need to prepare the depth data using
+[MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
+`data/imagenet_depth` pointing to a folder with two subfolders `train` and
+`val`, each mirroring the structure of the corresponding ImageNet folder
+described above and containing a `png` file for each of ImageNet's `JPEG`
+files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
+images. We provide the script `scripts/extract_depth.py` to generate this data.
+**Please note** that this script uses [MiDaS via PyTorch
+Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
+the hub provided the [MiDaS
+v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it
+provides a v2.1 version. We haven't tested our models with depth maps obtained
+via v2.1 and if you want to make sure that things work as expected, you must
+adjust the script to make sure it explicitly uses
+[v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
+
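+One lossless way to store such depth maps, assuming the float32-as-RGBA packing
+described above (the exact byte layout used by `scripts/extract_depth.py` may differ):
+
+```
+import numpy as np
+from PIL import Image
+
+def save_depth_as_rgba(depth, path):
+    """Pack a float32 depth map of shape (H, W) into the 4 channels of an RGBA png."""
+    depth = np.ascontiguousarray(depth, dtype=np.float32)
+    rgba = depth.view(np.uint8).reshape(depth.shape[0], depth.shape[1], 4)
+    Image.fromarray(rgba, mode="RGBA").save(path)
+
+def load_depth_from_rgba(path):
+    """Inverse of save_depth_as_rgba; returns a float32 array of shape (H, W)."""
+    rgba = np.array(Image.open(path), dtype=np.uint8)
+    return rgba.view(np.float32).reshape(rgba.shape[0], rgba.shape[1])
+```
+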
+### CelebA-HQ
+Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
+files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
+repository](https://github.com/tkarras/progressive_growing_of_gans)).
+
+### FFHQ
+Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
+from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
+
+### S-FLCKR
+Unfortunately, we are not allowed to distribute the images we collected for the
+S-FLCKR dataset and can therefore only give a description of how it was produced.
+There are many resources on [collecting images from the
+web](https://github.com/adrianmrit/flickrdatasets) to get started.
+We collected sufficiently large images from [flickr](https://www.flickr.com)
+(see `data/flickr_tags.txt` for a full list of tags used to find images)
+and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
+(see `data/subreddits.txt` for all subreddits that were used).
+Overall, we collected 107625 images, and split them randomly into 96861
+training images and 10764 validation images. We then obtained segmentation
+masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
+trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
+reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
+example script for this process in `scripts/extract_segmentation.py`.
+
+### COCO
+Create a symlink `data/coco` containing the images from the 2017 split in
+`train2017` and `val2017`, and their annotations in `annotations`. Files can be
+obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
+the [Stuff+thing PNG-style annotations on COCO 2017
+trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
+annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
+should be placed under `data/cocostuffthings`.
+
+### ADE20k
+Create a symlink `data/ade20k_root` containing the contents of
+[ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
+from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
+
+## Training models
+
+### FacesHQ
+
+Train a VQGAN with
+```
+python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
+```
+
+Then, adjust the checkpoint path of the config key
+`model.params.first_stage_config.params.ckpt_path` in
+`configs/faceshq_transformer.yaml` (or download
+[2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
+corresponds to the preconfigured checkpoint path), then run
+```
+python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
+```
+
+### D-RIN
+
+Train a VQGAN on ImageNet with
+```
+python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
+```
+
+or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
+and place under `logs`. If you trained your own, adjust the path in the config
+key `model.params.first_stage_config.params.ckpt_path` of
+`configs/drin_transformer.yaml`.
+
+Train a VQGAN on Depth Maps of ImageNet with
+```
+python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
+```
+
+or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
+and place under `logs`. If you trained your own, adjust the path in the config
+key `model.params.cond_stage_config.params.ckpt_path` of
+`configs/drin_transformer.yaml`.
+
+To train the transformer, run
+```
+python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
+```
+
+## More Resources
+### Comparing Different First Stage Models
+The reconstruction and compression capabilities of different first stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
+In particular, the notebook compares two VQGANs with a downsampling factor of f=16 and codebook sizes of 1024 and 16384,
+a VQGAN with f=8 and 8192 codebook entries, and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which also has f=8 and 8192
+codebook entries).
+![firststages1](assets/first_stage_squirrels.png)
+![firststages2](assets/first_stage_mushrooms.png)
+
+### Other
+- A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
+- A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
+- A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
+by [ayulockin](https://github.com/ayulockin).
+- A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
+- Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
+
+### Text-to-Image Optimization via CLIP
+VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
+from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
+
+ - [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
+ - The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
+ - A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
+
+![txt2img](assets/birddrawnbyachild.png)
+
+Text prompt: *'A bird drawn by a child'*
+
+## Shout-outs
+Thanks to everyone who makes their code and models available. In particular,
+
+- The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
+- The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
+- The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
+
+## BibTeX
+
+```
+@misc{esser2020taming,
+ title={Taming Transformers for High-Resolution Image Synthesis},
+ author={Patrick Esser and Robin Rombach and Björn Ommer},
+ year={2020},
+ eprint={2012.09841},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/repositories/taming-transformers/configs/coco_cond_stage.yaml b/repositories/taming-transformers/configs/coco_cond_stage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..18a3dde147455c281a3687b1a0b42bbbc3fb2725
--- /dev/null
+++ b/repositories/taming-transformers/configs/coco_cond_stage.yaml
@@ -0,0 +1,49 @@
+model:
+ base_learning_rate: 4.5e-06
+ target: taming.models.vqgan.VQSegmentationModel
+ params:
+ embed_dim: 256
+ n_embed: 1024
+ image_key: "segmentation"
+ n_labels: 183
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 183
+ out_ch: 183
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+
+ lossconfig:
+ target: taming.modules.losses.segmentation.BCELossWithQuant
+ params:
+ codebook_weight: 1.0
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 12
+ train:
+ target: taming.data.coco.CocoImagesAndCaptionsTrain
+ params:
+ size: 296
+ crop_size: 256
+ onehot_segmentation: true
+ use_stuffthing: true
+ validation:
+ target: taming.data.coco.CocoImagesAndCaptionsValidation
+ params:
+ size: 256
+ crop_size: 256
+ onehot_segmentation: true
+ use_stuffthing: true
diff --git a/repositories/taming-transformers/configs/coco_scene_images_transformer.yaml b/repositories/taming-transformers/configs/coco_scene_images_transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a03078de708182cc175f139078f8455ca3ec8a09
--- /dev/null
+++ b/repositories/taming-transformers/configs/coco_scene_images_transformer.yaml
@@ -0,0 +1,80 @@
+model:
+ base_learning_rate: 4.5e-06
+ target: taming.models.cond_transformer.Net2NetTransformer
+ params:
+ cond_stage_key: objects_bbox
+ transformer_config:
+ target: taming.modules.transformer.mingpt.GPT
+ params:
+ vocab_size: 8192
+ block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
+ n_layer: 40
+ n_head: 16
+ n_embd: 1408
+ embd_pdrop: 0.1
+ resid_pdrop: 0.1
+ attn_pdrop: 0.1
+ first_stage_config:
+ target: taming.models.vqgan.VQModel
+ params:
+ ckpt_path: /path/to/coco_epoch117.ckpt # https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/
+ embed_dim: 256
+ n_embed: 8192
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+ lossconfig:
+ target: taming.modules.losses.DummyLoss
+ cond_stage_config:
+ target: taming.models.dummy_cond_stage.DummyCondStage
+ params:
+ conditional_key: objects_bbox
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 6
+ train:
+ target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+ params:
+ data_path: data/coco_annotations_100 # substitute with path to full dataset
+ split: train
+ keys: [image, objects_bbox, file_name, annotations]
+ no_tokens: 8192
+ target_image_size: 256
+ min_object_area: 0.00001
+ min_objects_per_image: 2
+ max_objects_per_image: 30
+ crop_method: random-1d
+ random_flip: true
+ use_group_parameter: true
+ encode_crop: true
+ validation:
+ target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
+ params:
+ data_path: data/coco_annotations_100 # substitute with path to full dataset
+ split: validation
+ keys: [image, objects_bbox, file_name, annotations]
+ no_tokens: 8192
+ target_image_size: 256
+ min_object_area: 0.00001
+ min_objects_per_image: 2
+ max_objects_per_image: 30
+ crop_method: center
+ random_flip: false
+ use_group_parameter: true
+ encode_crop: true
diff --git a/repositories/taming-transformers/configs/custom_vqgan.yaml b/repositories/taming-transformers/configs/custom_vqgan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..908687f38325dfe49430692d668733a8e1598375
--- /dev/null
+++ b/repositories/taming-transformers/configs/custom_vqgan.yaml
@@ -0,0 +1,43 @@
+model:
+ base_learning_rate: 4.5e-6
+ target: taming.models.vqgan.VQModel
+ params:
+ embed_dim: 256
+ n_embed: 1024
+ ddconfig:
+ double_z: False
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
+ num_res_blocks: 2
+ attn_resolutions: [16]
+ dropout: 0.0
+
+ lossconfig:
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
+ params:
+ disc_conditional: False
+ disc_in_channels: 3
+ disc_start: 10000
+ disc_weight: 0.8
+ codebook_weight: 1.0
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 5
+ num_workers: 8
+ train:
+ target: taming.data.custom.CustomTrain
+ params:
+ training_images_list_file: some/training.txt
+ size: 256
+ validation:
+ target: taming.data.custom.CustomTest
+ params:
+ test_images_list_file: some/test.txt
+ size: 256
+
diff --git a/repositories/taming-transformers/configs/drin_transformer.yaml b/repositories/taming-transformers/configs/drin_transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bead4567d2dcc3d0f1a7b8eec823df4b427cab07
--- /dev/null
+++ b/repositories/taming-transformers/configs/drin_transformer.yaml
@@ -0,0 +1,77 @@
+model:
+ base_learning_rate: 4.5e-06
+ target: taming.models.cond_transformer.Net2NetTransformer
+ params:
+ cond_stage_key: depth
+ transformer_config:
+ target: taming.modules.transformer.mingpt.GPT
+ params:
+ vocab_size: 1024
+ block_size: 512
+ n_layer: 24
+ n_head: 16
+ n_embd: 1024
+ first_stage_config:
+ target: taming.models.vqgan.VQModel
+ params:
+ ckpt_path: logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt
+ embed_dim: 256
+ n_embed: 1024
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+ lossconfig:
+ target: taming.modules.losses.DummyLoss
+ cond_stage_config:
+ target: taming.models.vqgan.VQModel
+ params:
+ ckpt_path: logs/2020-11-03T15-34-24_imagenetdepth_vqgan/checkpoints/last.ckpt
+ embed_dim: 256
+ n_embed: 1024
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 1
+ out_ch: 1
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+ lossconfig:
+ target: taming.modules.losses.DummyLoss
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 2
+ num_workers: 8
+ train:
+ target: taming.data.imagenet.RINTrainWithDepth
+ params:
+ size: 256
+ validation:
+ target: taming.data.imagenet.RINValidationWithDepth
+ params:
+ size: 256
diff --git a/repositories/taming-transformers/configs/faceshq_transformer.yaml b/repositories/taming-transformers/configs/faceshq_transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b93391f9c9c41d63d28dd38dbf83615552642db3
--- /dev/null
+++ b/repositories/taming-transformers/configs/faceshq_transformer.yaml
@@ -0,0 +1,61 @@
+model:
+ base_learning_rate: 4.5e-06
+ target: taming.models.cond_transformer.Net2NetTransformer
+ params:
+ cond_stage_key: coord
+ transformer_config:
+ target: taming.modules.transformer.mingpt.GPT
+ params:
+ vocab_size: 1024
+ block_size: 512
+ n_layer: 24
+ n_head: 16
+ n_embd: 1024
+ first_stage_config:
+ target: taming.models.vqgan.VQModel
+ params:
+ ckpt_path: logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt
+ embed_dim: 256
+ n_embed: 1024
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+ lossconfig:
+ target: taming.modules.losses.DummyLoss
+ cond_stage_config:
+ target: taming.modules.misc.coord.CoordStage
+ params:
+ n_embed: 1024
+ down_factor: 16
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 2
+ num_workers: 8
+ train:
+ target: taming.data.faceshq.FacesHQTrain
+ params:
+ size: 256
+ crop_size: 256
+ coord: True
+ validation:
+ target: taming.data.faceshq.FacesHQValidation
+ params:
+ size: 256
+ crop_size: 256
+ coord: True
diff --git a/repositories/taming-transformers/configs/faceshq_vqgan.yaml b/repositories/taming-transformers/configs/faceshq_vqgan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3960f784551bfd9caddb5b084fc592c6eed6483b
--- /dev/null
+++ b/repositories/taming-transformers/configs/faceshq_vqgan.yaml
@@ -0,0 +1,42 @@
+model:
+ base_learning_rate: 4.5e-6
+ target: taming.models.vqgan.VQModel
+ params:
+ embed_dim: 256
+ n_embed: 1024
+ ddconfig:
+ double_z: False
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
+ num_res_blocks: 2
+ attn_resolutions: [16]
+ dropout: 0.0
+
+ lossconfig:
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
+ params:
+ disc_conditional: False
+ disc_in_channels: 3
+ disc_start: 30001
+ disc_weight: 0.8
+ codebook_weight: 1.0
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 3
+ num_workers: 8
+ train:
+ target: taming.data.faceshq.FacesHQTrain
+ params:
+ size: 256
+ crop_size: 256
+ validation:
+ target: taming.data.faceshq.FacesHQValidation
+ params:
+ size: 256
+ crop_size: 256
diff --git a/repositories/taming-transformers/configs/imagenet_vqgan.yaml b/repositories/taming-transformers/configs/imagenet_vqgan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f6dc21ff6de9a26474fa18a0d496a4a0b9bb0837
--- /dev/null
+++ b/repositories/taming-transformers/configs/imagenet_vqgan.yaml
@@ -0,0 +1,42 @@
+model:
+ base_learning_rate: 4.5e-6
+ target: taming.models.vqgan.VQModel
+ params:
+ embed_dim: 256
+ n_embed: 1024
+ ddconfig:
+ double_z: False
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
+ num_res_blocks: 2
+ attn_resolutions: [16]
+ dropout: 0.0
+
+ lossconfig:
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
+ params:
+ disc_conditional: False
+ disc_in_channels: 3
+ disc_start: 250001
+ disc_weight: 0.8
+ codebook_weight: 1.0
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 12
+ num_workers: 24
+ train:
+ target: taming.data.imagenet.ImageNetTrain
+ params:
+ config:
+ size: 256
+ validation:
+ target: taming.data.imagenet.ImageNetValidation
+ params:
+ config:
+ size: 256
diff --git a/repositories/taming-transformers/configs/imagenetdepth_vqgan.yaml b/repositories/taming-transformers/configs/imagenetdepth_vqgan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..88d2f34f1c0661e350899cf4229cdf60697baf0d
--- /dev/null
+++ b/repositories/taming-transformers/configs/imagenetdepth_vqgan.yaml
@@ -0,0 +1,41 @@
+model:
+ base_learning_rate: 4.5e-6
+ target: taming.models.vqgan.VQModel
+ params:
+ embed_dim: 256
+ n_embed: 1024
+ image_key: depth
+ ddconfig:
+ double_z: False
+ z_channels: 256
+ resolution: 256
+ in_channels: 1
+ out_ch: 1
+ ch: 128
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
+ num_res_blocks: 2
+ attn_resolutions: [16]
+ dropout: 0.0
+
+ lossconfig:
+ target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
+ params:
+ disc_conditional: False
+ disc_in_channels: 1
+ disc_start: 50001
+ disc_weight: 0.75
+ codebook_weight: 1.0
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 3
+ num_workers: 8
+ train:
+ target: taming.data.imagenet.ImageNetTrainWithDepth
+ params:
+ size: 256
+ validation:
+ target: taming.data.imagenet.ImageNetValidationWithDepth
+ params:
+ size: 256
diff --git a/repositories/taming-transformers/configs/open_images_scene_images_transformer.yaml b/repositories/taming-transformers/configs/open_images_scene_images_transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4e41e0d67d3e3eb17509862e063e4d626b06d4b
--- /dev/null
+++ b/repositories/taming-transformers/configs/open_images_scene_images_transformer.yaml
@@ -0,0 +1,86 @@
+model:
+ base_learning_rate: 4.5e-06
+ target: taming.models.cond_transformer.Net2NetTransformer
+ params:
+ cond_stage_key: objects_bbox
+ transformer_config:
+ target: taming.modules.transformer.mingpt.GPT
+ params:
+ vocab_size: 8192
+ block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
+ n_layer: 36
+ n_head: 16
+ n_embd: 1536
+ embd_pdrop: 0.1
+ resid_pdrop: 0.1
+ attn_pdrop: 0.1
+ first_stage_config:
+ target: taming.models.vqgan.VQModel
+ params:
+ ckpt_path: /path/to/coco_oi_epoch12.ckpt # https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/
+ embed_dim: 256
+ n_embed: 8192
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+ lossconfig:
+ target: taming.modules.losses.DummyLoss
+ cond_stage_config:
+ target: taming.models.dummy_cond_stage.DummyCondStage
+ params:
+ conditional_key: objects_bbox
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 6
+ train:
+ target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
+ params:
+ data_path: data/open_images_annotations_100 # substitute with path to full dataset
+ split: train
+ keys: [image, objects_bbox, file_name, annotations]
+ no_tokens: 8192
+ target_image_size: 256
+ category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
+ category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
+ min_object_area: 0.0001
+ min_objects_per_image: 2
+ max_objects_per_image: 30
+ crop_method: random-2d
+ random_flip: true
+ use_group_parameter: true
+ use_additional_parameters: true
+ encode_crop: true
+ validation:
+ target: taming.data.annotated_objects_open_images.AnnotatedObjectsOpenImages
+ params:
+ data_path: data/open_images_annotations_100 # substitute with path to full dataset
+ split: validation
+ keys: [image, objects_bbox, file_name, annotations]
+ no_tokens: 8192
+ target_image_size: 256
+ category_allow_list_target: taming.data.open_images_helper.top_300_classes_plus_coco_compatibility
+ category_mapping_target: taming.data.open_images_helper.open_images_unify_categories_for_coco
+ min_object_area: 0.0001
+ min_objects_per_image: 2
+ max_objects_per_image: 30
+ crop_method: center
+ random_flip: false
+ use_group_parameter: true
+ use_additional_parameters: true
+ encode_crop: true
diff --git a/repositories/taming-transformers/configs/sflckr_cond_stage.yaml b/repositories/taming-transformers/configs/sflckr_cond_stage.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d48b50a700c4b44098ce4f3c752d5a4d7158f8a9
--- /dev/null
+++ b/repositories/taming-transformers/configs/sflckr_cond_stage.yaml
@@ -0,0 +1,43 @@
+model:
+ base_learning_rate: 4.5e-06
+ target: taming.models.vqgan.VQSegmentationModel
+ params:
+ embed_dim: 256
+ n_embed: 1024
+ image_key: "segmentation"
+ n_labels: 182
+ ddconfig:
+ double_z: false
+ z_channels: 256
+ resolution: 256
+ in_channels: 182
+ out_ch: 182
+ ch: 128
+ ch_mult:
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ num_res_blocks: 2
+ attn_resolutions:
+ - 16
+ dropout: 0.0
+
+ lossconfig:
+ target: taming.modules.losses.segmentation.BCELossWithQuant
+ params:
+ codebook_weight: 1.0
+
+data:
+  target: main.DataModuleFromConfig
+ params:
+ batch_size: 12
+ train:
+ target: taming.data.sflckr.Examples # adjust
+ params:
+ size: 256
+ validation:
+ target: taming.data.sflckr.Examples # adjust
+ params:
+ size: 256
diff --git a/repositories/taming-transformers/environment.yaml b/repositories/taming-transformers/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3fbba586e55dbe64184d006319fad969805ef16f
--- /dev/null
+++ b/repositories/taming-transformers/environment.yaml
@@ -0,0 +1,25 @@
+name: taming
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.8.5
+ - pip=20.3
+ - cudatoolkit=10.2
+ - pytorch=1.7.0
+ - torchvision=0.8.1
+ - numpy=1.19.2
+ - pip:
+ - albumentations==0.4.3
+ - opencv-python==4.1.2.30
+ - pudb==2019.2
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.0.8
+ - omegaconf==2.0.0
+ - test-tube>=0.7.5
+ - streamlit>=0.73.1
+ - einops==0.3.0
+ - more-itertools>=8.0.0
+ - transformers==4.3.1
+ - -e .
diff --git a/repositories/taming-transformers/main.py b/repositories/taming-transformers/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d83cb21c1dc5d0f5d7f396479c74e64691ec364
--- /dev/null
+++ b/repositories/taming-transformers/main.py
@@ -0,0 +1,585 @@
+import argparse, os, sys, datetime, glob, importlib
+from omegaconf import OmegaConf
+import numpy as np
+from PIL import Image
+import torch
+import torchvision
+from torch.utils.data import random_split, DataLoader, Dataset
+import pytorch_lightning as pl
+from pytorch_lightning import seed_everything
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+from taming.data.utils import custom_collate
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def get_parser(**parser_kwargs):
+ def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "-n",
+ "--name",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="postfix for logdir",
+ )
+ parser.add_argument(
+ "-r",
+ "--resume",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="resume from logdir or checkpoint in logdir",
+ )
+ parser.add_argument(
+ "-b",
+ "--base",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right. "
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+ default=list(),
+ )
+ parser.add_argument(
+ "-t",
+ "--train",
+ type=str2bool,
+ const=True,
+ default=False,
+ nargs="?",
+ help="train",
+ )
+ parser.add_argument(
+ "--no-test",
+ type=str2bool,
+ const=True,
+ default=False,
+ nargs="?",
+ help="disable test",
+ )
+ parser.add_argument("-p", "--project", help="name of new or path to existing project")
+ parser.add_argument(
+ "-d",
+ "--debug",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="enable post-mortem debugging",
+ )
+ parser.add_argument(
+ "-s",
+ "--seed",
+ type=int,
+ default=23,
+ help="seed for seed_everything",
+ )
+ parser.add_argument(
+ "-f",
+ "--postfix",
+ type=str,
+ default="",
+ help="post-postfix for default name",
+ )
+
+ return parser
+
+
+def nondefault_trainer_args(opt):
+ parser = argparse.ArgumentParser()
+ parser = Trainer.add_argparse_args(parser)
+ args = parser.parse_args([])
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
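+# For illustration: a config of the form {"target": "<import.path.Class>", "params": {...}}
+# is turned into an instance of that class, e.g.
+#   instantiate_from_config({"target": "torch.nn.Linear",
+#                            "params": {"in_features": 4, "out_features": 2}})
+# returns torch.nn.Linear(in_features=4, out_features=2); the yaml files under configs/
+# follow exactly this target/params convention.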
+
+class WrappedDataset(Dataset):
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
+ def __init__(self, dataset):
+ self.data = dataset
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+ def __init__(self, batch_size, train=None, validation=None, test=None,
+ wrap=False, num_workers=None):
+ super().__init__()
+ self.batch_size = batch_size
+ self.dataset_configs = dict()
+ self.num_workers = num_workers if num_workers is not None else batch_size*2
+ if train is not None:
+ self.dataset_configs["train"] = train
+ self.train_dataloader = self._train_dataloader
+ if validation is not None:
+ self.dataset_configs["validation"] = validation
+ self.val_dataloader = self._val_dataloader
+ if test is not None:
+ self.dataset_configs["test"] = test
+ self.test_dataloader = self._test_dataloader
+ self.wrap = wrap
+
+ def prepare_data(self):
+ for data_cfg in self.dataset_configs.values():
+ instantiate_from_config(data_cfg)
+
+ def setup(self, stage=None):
+ self.datasets = dict(
+ (k, instantiate_from_config(self.dataset_configs[k]))
+ for k in self.dataset_configs)
+ if self.wrap:
+ for k in self.datasets:
+ self.datasets[k] = WrappedDataset(self.datasets[k])
+
+ def _train_dataloader(self):
+ return DataLoader(self.datasets["train"], batch_size=self.batch_size,
+ num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
+
+ def _val_dataloader(self):
+ return DataLoader(self.datasets["validation"],
+ batch_size=self.batch_size,
+ num_workers=self.num_workers, collate_fn=custom_collate)
+
+ def _test_dataloader(self):
+ return DataLoader(self.datasets["test"], batch_size=self.batch_size,
+ num_workers=self.num_workers, collate_fn=custom_collate)
+
+
+class SetupCallback(Callback):
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
+ super().__init__()
+ self.resume = resume
+ self.now = now
+ self.logdir = logdir
+ self.ckptdir = ckptdir
+ self.cfgdir = cfgdir
+ self.config = config
+ self.lightning_config = lightning_config
+
+ def on_pretrain_routine_start(self, trainer, pl_module):
+ if trainer.global_rank == 0:
+ # Create logdirs and save configs
+ os.makedirs(self.logdir, exist_ok=True)
+ os.makedirs(self.ckptdir, exist_ok=True)
+ os.makedirs(self.cfgdir, exist_ok=True)
+
+ print("Project config")
+ print(self.config.pretty())
+ OmegaConf.save(self.config,
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
+
+ print("Lightning config")
+ print(self.lightning_config.pretty())
+ OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
+
+ else:
+ # ModelCheckpoint callback created log directory --- remove it
+ if not self.resume and os.path.exists(self.logdir):
+ dst, name = os.path.split(self.logdir)
+ dst = os.path.join(dst, "child_runs", name)
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
+ try:
+ os.rename(self.logdir, dst)
+ except FileNotFoundError:
+ pass
+
+
+class ImageLogger(Callback):
+ def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
+ super().__init__()
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ self.logger_log_images = {
+ pl.loggers.WandbLogger: self._wandb,
+ pl.loggers.TestTubeLogger: self._testtube,
+ }
+ self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
+ self.clamp = clamp
+
+ @rank_zero_only
+ def _wandb(self, pl_module, images, batch_idx, split):
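+        # the unconditional raise below effectively disables wandb image logging; the
+        # remaining lines are unreachable as written (note that `wandb` is never
+        # imported in this file)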
+ raise ValueError("No way wandb")
+ grids = dict()
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k])
+ grids[f"{split}/{k}"] = wandb.Image(grid)
+ pl_module.logger.experiment.log(grids)
+
+ @rank_zero_only
+ def _testtube(self, pl_module, images, batch_idx, split):
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k])
+ grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
+
+ tag = f"{split}/{k}"
+ pl_module.logger.experiment.add_image(
+ tag, grid,
+ global_step=pl_module.global_step)
+
+ @rank_zero_only
+ def log_local(self, save_dir, split, images,
+ global_step, current_epoch, batch_idx):
+ root = os.path.join(save_dir, "images", split)
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
+
+ grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid*255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+ k,
+ global_step,
+ current_epoch,
+ batch_idx)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
+ if (self.check_frequency(batch_idx) and # batch_idx % self.batch_freq == 0
+ hasattr(pl_module, "log_images") and
+ callable(pl_module.log_images) and
+ self.max_images > 0):
+ logger = type(pl_module.logger)
+
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ with torch.no_grad():
+ images = pl_module.log_images(batch, split=split, pl_module=pl_module)
+
+ for k in images:
+ N = min(images[k].shape[0], self.max_images)
+ images[k] = images[k][:N]
+ if isinstance(images[k], torch.Tensor):
+ images[k] = images[k].detach().cpu()
+ if self.clamp:
+ images[k] = torch.clamp(images[k], -1., 1.)
+
+ self.log_local(pl_module.logger.save_dir, split, images,
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+ logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
+ logger_log_images(pl_module, images, pl_module.global_step, split)
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, batch_idx):
+ if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
+ try:
+ self.log_steps.pop(0)
+ except IndexError:
+ pass
+ return True
+ return False
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ self.log_img(pl_module, batch, batch_idx, split="train")
+
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ self.log_img(pl_module, batch, batch_idx, split="val")
+
+
+
+if __name__ == "__main__":
+ # custom parser to specify config files, train, test and debug mode,
+ # postfix, resume.
+ # `--key value` arguments are interpreted as arguments to the trainer.
+ # `nested.key=value` arguments are interpreted as config parameters.
+ # configs are merged from left-to-right followed by command line parameters.
+
+ # model:
+ # base_learning_rate: float
+ # target: path to lightning module
+ # params:
+ # key: value
+ # data:
+ # target: main.DataModuleFromConfig
+ # params:
+ # batch_size: int
+ # wrap: bool
+ # train:
+ # target: path to train dataset
+ # params:
+ # key: value
+ # validation:
+ # target: path to validation dataset
+ # params:
+ # key: value
+ # test:
+ # target: path to test dataset
+ # params:
+ # key: value
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
+ # trainer:
+ # additional arguments to trainer
+ # logger:
+ # logger to instantiate
+ # modelcheckpoint:
+ # modelcheckpoint to instantiate
+ # callbacks:
+ # callback1:
+ # target: importpath
+ # params:
+ # key: value
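+    #
+    # example invocation (the override values are only illustrative):
+    #   python main.py --base configs/custom_vqgan.yaml -t True --gpus 0, \
+    #       model.base_learning_rate=1.0e-06 data.params.batch_size=2
+    # `--gpus 0,` is picked up as a trainer argument, while the dotted `key=value`
+    # pairs are merged into the config via OmegaConf.from_dotlist below.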
+
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+
+ # add cwd for convenience and to make classes in this file available when
+ # running as `python main.py`
+ # (in particular `main.DataModuleFromConfig`)
+ sys.path.append(os.getcwd())
+
+ parser = get_parser()
+ parser = Trainer.add_argparse_args(parser)
+
+ opt, unknown = parser.parse_known_args()
+ if opt.name and opt.resume:
+ raise ValueError(
+ "-n/--name and -r/--resume cannot be specified both."
+ "If you want to resume training in a new log folder, "
+ "use -n/--name in combination with --resume_from_checkpoint"
+ )
+ if opt.resume:
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ idx = len(paths)-paths[::-1].index("logs")+1
+ logdir = "/".join(paths[:idx])
+ ckpt = opt.resume
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+
+ opt.resume_from_checkpoint = ckpt
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
+ opt.base = base_configs+opt.base
+ _tmp = logdir.split("/")
+ nowname = _tmp[_tmp.index("logs")+1]
+ else:
+ if opt.name:
+ name = "_"+opt.name
+ elif opt.base:
+ cfg_fname = os.path.split(opt.base[0])[-1]
+ cfg_name = os.path.splitext(cfg_fname)[0]
+ name = "_"+cfg_name
+ else:
+ name = ""
+ nowname = now+name+opt.postfix
+ logdir = os.path.join("logs", nowname)
+
+ ckptdir = os.path.join(logdir, "checkpoints")
+ cfgdir = os.path.join(logdir, "configs")
+ seed_everything(opt.seed)
+
+ try:
+ # init and save configs
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ config = OmegaConf.merge(*configs, cli)
+ lightning_config = config.pop("lightning", OmegaConf.create())
+ # merge trainer cli with config
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
+ # default to ddp
+ trainer_config["distributed_backend"] = "ddp"
+ for k in nondefault_trainer_args(opt):
+ trainer_config[k] = getattr(opt, k)
+        if "gpus" not in trainer_config:
+ del trainer_config["distributed_backend"]
+ cpu = True
+ else:
+ gpuinfo = trainer_config["gpus"]
+ print(f"Running on GPUs {gpuinfo}")
+ cpu = False
+ trainer_opt = argparse.Namespace(**trainer_config)
+ lightning_config.trainer = trainer_config
+
+ # model
+ model = instantiate_from_config(config.model)
+
+ # trainer and callbacks
+ trainer_kwargs = dict()
+
+ # default logger configs
+ # NOTE wandb < 0.10.0 interferes with shutdown
+ # wandb >= 0.10.0 seems to fix it but still interferes with pudb
+ # debugging (wrongly sized pudb ui)
+ # thus prefer testtube for now
+ default_logger_cfgs = {
+ "wandb": {
+ "target": "pytorch_lightning.loggers.WandbLogger",
+ "params": {
+ "name": nowname,
+ "save_dir": logdir,
+ "offline": opt.debug,
+ "id": nowname,
+ }
+ },
+ "testtube": {
+ "target": "pytorch_lightning.loggers.TestTubeLogger",
+ "params": {
+ "name": "testtube",
+ "save_dir": logdir,
+ }
+ },
+ }
+ default_logger_cfg = default_logger_cfgs["testtube"]
+ logger_cfg = lightning_config.logger or OmegaConf.create()
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
+ # specify which metric is used to determine best models
+ default_modelckpt_cfg = {
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+ "params": {
+ "dirpath": ckptdir,
+ "filename": "{epoch:06}",
+ "verbose": True,
+ "save_last": True,
+ }
+ }
+ if hasattr(model, "monitor"):
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
+ default_modelckpt_cfg["params"]["save_top_k"] = 3
+
+ modelckpt_cfg = lightning_config.modelcheckpoint or OmegaConf.create()
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
+ trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
+
+ # add callback which sets up log directory
+ default_callbacks_cfg = {
+ "setup_callback": {
+ "target": "main.SetupCallback",
+ "params": {
+ "resume": opt.resume,
+ "now": now,
+ "logdir": logdir,
+ "ckptdir": ckptdir,
+ "cfgdir": cfgdir,
+ "config": config,
+ "lightning_config": lightning_config,
+ }
+ },
+ "image_logger": {
+ "target": "main.ImageLogger",
+ "params": {
+ "batch_frequency": 750,
+ "max_images": 4,
+ "clamp": True
+ }
+ },
+ "learning_rate_logger": {
+ "target": "main.LearningRateMonitor",
+ "params": {
+ "logging_interval": "step",
+ #"log_momentum": True
+ }
+ },
+ }
+ callbacks_cfg = lightning_config.callbacks or OmegaConf.create()
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
+ trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+
+ trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+
+ # data
+ data = instantiate_from_config(config.data)
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+ # calling these ourselves should not be necessary but it is.
+ # lightning still takes care of proper multiprocessing though
+ data.prepare_data()
+ data.setup()
+
+ # configure learning rate
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+ if not cpu:
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
+ else:
+ ngpu = 1
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
+ print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
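+        # e.g. (illustrative numbers) accumulate_grad_batches=2, ngpu=4, bs=8 and
+        # base_lr=4.5e-6 would give model.learning_rate = 2 * 4 * 8 * 4.5e-6 = 2.88e-4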
+
+ # allow checkpointing via USR1
+ def melk(*args, **kwargs):
+ # run all checkpoint hooks
+ if trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
+ trainer.save_checkpoint(ckpt_path)
+
+ def divein(*args, **kwargs):
+ if trainer.global_rank == 0:
+ import pudb; pudb.set_trace()
+
+ import signal
+ signal.signal(signal.SIGUSR1, melk)
+ signal.signal(signal.SIGUSR2, divein)
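+        # usage (illustrative): from another shell, `kill -USR1 <pid>` makes the
+        # rank-0 process save checkpoints/last.ckpt, `kill -USR2 <pid>` drops it
+        # into pudb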
+
+ # run
+ if opt.train:
+ try:
+ trainer.fit(model, data)
+ except Exception:
+ melk()
+ raise
+ if not opt.no_test and not trainer.interrupted:
+ trainer.test(model, data)
+ except Exception:
+ if opt.debug and trainer.global_rank==0:
+ try:
+ import pudb as debugger
+ except ImportError:
+ import pdb as debugger
+ debugger.post_mortem()
+ raise
+ finally:
+ # move newly created debug project to debug_runs
+ if opt.debug and not opt.resume and trainer.global_rank==0:
+ dst, name = os.path.split(logdir)
+ dst = os.path.join(dst, "debug_runs", name)
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
+ os.rename(logdir, dst)
diff --git a/repositories/taming-transformers/scripts/extract_depth.py b/repositories/taming-transformers/scripts/extract_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6aa0d80c63a3e580fa28e0f2c7af4e9ae003b64
--- /dev/null
+++ b/repositories/taming-transformers/scripts/extract_depth.py
@@ -0,0 +1,112 @@
+import os
+import torch
+import numpy as np
+from tqdm import trange
+from PIL import Image
+
+
+def get_state(gpu):
+ import torch
+ midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
+ if gpu:
+ midas.cuda()
+ midas.eval()
+
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+ transform = midas_transforms.default_transform
+
+ state = {"model": midas,
+ "transform": transform}
+ return state
+
+
+def depth_to_rgba(x):
+ assert x.dtype == np.float32
+ assert len(x.shape) == 2
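+    # reinterpret each float32 depth value as 4 uint8 bytes so the map can be
+    # stored losslessly as an RGBA png (rgba_to_depth below is the inverse)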
+ y = x.copy()
+ y.dtype = np.uint8
+ y = y.reshape(x.shape+(4,))
+ return np.ascontiguousarray(y)
+
+
+def rgba_to_depth(x):
+ assert x.dtype == np.uint8
+ assert len(x.shape) == 3 and x.shape[2] == 4
+ y = x.copy()
+ y.dtype = np.float32
+ y = y.reshape(x.shape[:2])
+ return np.ascontiguousarray(y)
+
+
+def run(x, state):
+ model = state["model"]
+ transform = state["transform"]
+ hw = x.shape[:2]
+ with torch.no_grad():
+ prediction = model(transform((x + 1.0) * 127.5).cuda())
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=hw,
+ mode="bicubic",
+ align_corners=False,
+ ).squeeze()
+ output = prediction.cpu().numpy()
+ return output
+
+
+def get_filename(relpath, level=-2):
+ # save class folder structure and filename:
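+    # e.g. (illustrative) "n01440764/n01440764_10026.JPEG" -> ("n01440764", "n01440764_10026")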
+ fn = relpath.split(os.sep)[level:]
+ folder = fn[-2]
+ file = fn[-1].split('.')[0]
+ return folder, file
+
+
+def save_depth(dataset, path, debug=False):
+ os.makedirs(path)
+    N = len(dataset)
+ if debug:
+ N = 10
+ state = get_state(gpu=True)
+ for idx in trange(N, desc="Data"):
+ ex = dataset[idx]
+ image, relpath = ex["image"], ex["relpath"]
+ folder, filename = get_filename(relpath)
+ # prepare
+ folderabspath = os.path.join(path, folder)
+ os.makedirs(folderabspath, exist_ok=True)
+ savepath = os.path.join(folderabspath, filename)
+ # run model
+ xout = run(image, state)
+ I = depth_to_rgba(xout)
+ Image.fromarray(I).save("{}.png".format(savepath))
+
+
+if __name__ == "__main__":
+ from taming.data.imagenet import ImageNetTrain, ImageNetValidation
+ out = "data/imagenet_depth"
+ if not os.path.exists(out):
+ print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
+ "(be prepared that the output size will be larger than ImageNet itself).")
+ exit(1)
+
+ # go
+ dset = ImageNetValidation()
+ abspath = os.path.join(out, "val")
+ if os.path.exists(abspath):
+ print("{} exists - not doing anything.".format(abspath))
+ else:
+ print("preparing {}".format(abspath))
+ save_depth(dset, abspath)
+ print("done with validation split")
+
+ dset = ImageNetTrain()
+ abspath = os.path.join(out, "train")
+ if os.path.exists(abspath):
+ print("{} exists - not doing anything.".format(abspath))
+ else:
+ print("preparing {}".format(abspath))
+ save_depth(dset, abspath)
+ print("done with train split")
+
+ print("done done.")
diff --git a/repositories/taming-transformers/scripts/extract_segmentation.py b/repositories/taming-transformers/scripts/extract_segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..235b3c4b4575981b7533ce18bceaff97e05b55f9
--- /dev/null
+++ b/repositories/taming-transformers/scripts/extract_segmentation.py
@@ -0,0 +1,130 @@
+import sys, os
+import numpy as np
+import scipy
+import torch
+import torch.nn as nn
+from scipy import ndimage
+from tqdm import tqdm, trange
+from PIL import Image
+import torch.hub
+import torchvision
+import torch.nn.functional as F
+
+# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
+# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
+# and put the path here
+CKPT_PATH = "TODO"
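+# e.g. (illustrative placeholder) CKPT_PATH = "/path/to/deeplabv2_resnet101_msc-cocostuff164k-100000.pth"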
+
+rescale = lambda x: (x + 1.) / 2.
+
+def rescale_bgr(x):
+ x = (x+1)*127.5
+ x = torch.flip(x, dims=[0])
+ return x
+
+
+class COCOStuffSegmenter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.n_labels = 182
+ model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
+ ckpt_path = CKPT_PATH
+ model.load_state_dict(torch.load(ckpt_path))
+ self.model = model
+
+ normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
+ self.image_transform = torchvision.transforms.Compose([
+ torchvision.transforms.Lambda(lambda image: torch.stack(
+ [normalize(rescale_bgr(x)) for x in image]))
+ ])
+
+ def forward(self, x, upsample=None):
+ x = self._pre_process(x)
+ x = self.model(x)
+ if upsample is not None:
+ x = torch.nn.functional.upsample_bilinear(x, size=upsample)
+ return x
+
+ def _pre_process(self, x):
+ x = self.image_transform(x)
+ return x
+
+ @property
+ def mean(self):
+ # bgr
+ return [104.008, 116.669, 122.675]
+
+ @property
+ def std(self):
+ return [1.0, 1.0, 1.0]
+
+ @property
+ def input_size(self):
+ return [3, 224, 224]
+
+
+def run_model(img, model):
+ model = model.eval()
+ with torch.no_grad():
+ segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
+ segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
+ return segmentation.detach().cpu()
+
+
+def get_input(batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ return x.float()
+
+
+def save_segmentation(segmentation, path):
+ # --> class label to uint8, save as png
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ assert len(segmentation.shape)==4
+ assert segmentation.shape[0]==1
+ for seg in segmentation:
+ seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
+ seg = Image.fromarray(seg)
+ seg.save(path)
+
+
+def iterate_dataset(dataloader, destpath, model):
+ os.makedirs(destpath, exist_ok=True)
+ num_processed = 0
+ for i, batch in tqdm(enumerate(dataloader), desc="Data"):
+ try:
+ img = get_input(batch, "image")
+ img = img.cuda()
+ seg = run_model(img, model)
+
+ path = batch["relative_file_path_"][0]
+ path = os.path.splitext(path)[0]
+
+ path = os.path.join(destpath, path + ".png")
+ save_segmentation(seg, path)
+ num_processed += 1
+ except Exception as e:
+ print(e)
+ print("but anyhow..")
+
+ print("Processed {} files. Bye.".format(num_processed))
+
+
+from taming.data.sflckr import Examples
+from torch.utils.data import DataLoader
+
+if __name__ == "__main__":
+ dest = sys.argv[1]
+ batchsize = 1
+ print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
+
+ model = COCOStuffSegmenter({}).cuda()
+ print("Instantiated model.")
+
+ dataset = Examples()
+ dloader = DataLoader(dataset, batch_size=batchsize)
+ iterate_dataset(dataloader=dloader, destpath=dest, model=model)
+ print("done.")
diff --git a/repositories/taming-transformers/scripts/extract_submodel.py b/repositories/taming-transformers/scripts/extract_submodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..559bc5e04281a7cf833a82e3cd48627b20f1a76d
--- /dev/null
+++ b/repositories/taming-transformers/scripts/extract_submodel.py
@@ -0,0 +1,17 @@
+import torch
+import sys
+
+if __name__ == "__main__":
+ inpath = sys.argv[1]
+ outpath = sys.argv[2]
+ submodel = "cond_stage_model"
+ if len(sys.argv) > 3:
+ submodel = sys.argv[3]
+
+ print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
+
+ sd = torch.load(inpath, map_location="cpu")
+ new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
+ for k,v in sd["state_dict"].items()
+ if k.startswith("cond_stage_model"))}
+ torch.save(new_sd, outpath)
diff --git a/repositories/taming-transformers/scripts/make_samples.py b/repositories/taming-transformers/scripts/make_samples.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e4d6995cd41cc07b4e8861cb941c6052b0f5517
--- /dev/null
+++ b/repositories/taming-transformers/scripts/make_samples.py
@@ -0,0 +1,292 @@
+import argparse, os, sys, glob, math, time
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from main import instantiate_from_config, DataModuleFromConfig
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from tqdm import trange
+
+
+def save_image(x, path):
+ c,h,w = x.shape
+ assert c==3
+ x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
+ Image.fromarray(x).save(path)
+
+
+@torch.no_grad()
+def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
+ if len(dsets.datasets) > 1:
+ split = sorted(dsets.datasets.keys())[0]
+ dset = dsets.datasets[split]
+ else:
+ dset = next(iter(dsets.datasets.values()))
+ print("Dataset: ", dset.__class__.__name__)
+ for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
+ indices = list(range(start_idx, start_idx+batch_size))
+ example = default_collate([dset[i] for i in indices])
+
+ x = model.get_input("image", example).to(model.device)
+ for i in range(x.shape[0]):
+ save_image(x[i], os.path.join(outdir, "originals",
+ "{:06}.png".format(indices[i])))
+
+ cond_key = model.cond_stage_key
+ c = model.get_input(cond_key, example).to(model.device)
+
+ scale_factor = 1.0
+ quant_z, z_indices = model.encode_to_z(x)
+ quant_c, c_indices = model.encode_to_c(c)
+
+ cshape = quant_z.shape
+
+ xrec = model.first_stage_model.decode(quant_z)
+ for i in range(xrec.shape[0]):
+ save_image(xrec[i], os.path.join(outdir, "reconstructions",
+ "{:06}.png".format(indices[i])))
+
+ if cond_key == "segmentation":
+ # get image from segmentation mask
+ num_classes = c.shape[1]
+ c = torch.argmax(c, dim=1, keepdim=True)
+ c = torch.nn.functional.one_hot(c, num_classes=num_classes)
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ c = model.cond_stage_model.to_rgb(c)
+
+ idx = z_indices
+
+ half_sample = False
+ if half_sample:
+ start = idx.shape[1]//2
+ else:
+ start = 0
+
+ idx[:,start:] = 0
+ idx = idx.reshape(cshape[0],cshape[2],cshape[3])
+ start_i = start//cshape[3]
+ start_j = start %cshape[3]
+
+ cidx = c_indices
+ cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
+
+ sample = True
+
+ for i in range(start_i,cshape[2]-0):
+ if i <= 8:
+ local_i = i
+ elif cshape[2]-i < 8:
+ local_i = 16-(cshape[2]-i)
+ else:
+ local_i = 8
+ for j in range(start_j,cshape[3]-0):
+ if j <= 8:
+ local_j = j
+ elif cshape[3]-j < 8:
+ local_j = 16-(cshape[3]-j)
+ else:
+ local_j = 8
+
+ i_start = i-local_i
+ i_end = i_start+16
+ j_start = j-local_j
+ j_end = j_start+16
+ patch = idx[:,i_start:i_end,j_start:j_end]
+ patch = patch.reshape(patch.shape[0],-1)
+ cpatch = cidx[:, i_start:i_end, j_start:j_end]
+ cpatch = cpatch.reshape(cpatch.shape[0], -1)
+ patch = torch.cat((cpatch, patch), dim=1)
+ logits,_ = model.transformer(patch[:,:-1])
+ logits = logits[:, -256:, :]
+ logits = logits.reshape(cshape[0],16,16,-1)
+ logits = logits[:,local_i,local_j,:]
+
+ logits = logits/temperature
+
+ if top_k is not None:
+ logits = model.top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ idx[:,i,j] = ix
+
+ xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
+ for i in range(xsample.shape[0]):
+ save_image(xsample[i], os.path.join(outdir, "samples",
+ "{:06}.png".format(indices[i])))
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-r",
+ "--resume",
+ type=str,
+ nargs="?",
+ help="load from logdir or checkpoint in logdir",
+ )
+ parser.add_argument(
+ "-b",
+ "--base",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right. "
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+ default=list(),
+ )
+ parser.add_argument(
+ "-c",
+ "--config",
+ nargs="?",
+ metavar="single_config.yaml",
+ help="path to single config. If specified, base configs will be ignored "
+ "(except for the last one if left unspecified).",
+ const=True,
+ default="",
+ )
+ parser.add_argument(
+ "--ignore_base_data",
+ action="store_true",
+ help="Ignore data specification from base configs. Useful if you want "
+        "to specify custom datasets on the command line.",
+ )
+ parser.add_argument(
+ "--outdir",
+ required=True,
+ type=str,
+ help="Where to write outputs to.",
+ )
+ parser.add_argument(
+ "--top_k",
+ type=int,
+ default=100,
+ help="Sample from among top-k predictions.",
+ )
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ default=1.0,
+ help="Sampling temperature.",
+ )
+ return parser
+
+
+def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+ if "ckpt_path" in config.params:
+ print("Deleting the restore-ckpt path from the config...")
+ config.params.ckpt_path = None
+ if "downsample_cond_size" in config.params:
+ print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
+ config.params.downsample_cond_size = -1
+ config.params["downsample_cond_factor"] = 0.5
+ try:
+ if "ckpt_path" in config.params.first_stage_config.params:
+ config.params.first_stage_config.params.ckpt_path = None
+ print("Deleting the first-stage restore-ckpt path from the config...")
+ if "ckpt_path" in config.params.cond_stage_config.params:
+ config.params.cond_stage_config.params.ckpt_path = None
+ print("Deleting the cond-stage restore-ckpt path from the config...")
+    except Exception:
+ pass
+
+ model = instantiate_from_config(config)
+ if sd is not None:
+ missing, unexpected = model.load_state_dict(sd, strict=False)
+ print(f"Missing Keys in State Dict: {missing}")
+ print(f"Unexpected Keys in State Dict: {unexpected}")
+ if gpu:
+ model.cuda()
+ if eval_mode:
+ model.eval()
+ return {"model": model}
+
+
+def get_data(config):
+ # get data
+ data = instantiate_from_config(config.data)
+ data.prepare_data()
+ data.setup()
+ return data
+
+
+def load_model_and_dset(config, ckpt, gpu, eval_mode):
+ # get data
+ dsets = get_data(config) # calls data.config ...
+
+ # now load the specified checkpoint
+ if ckpt:
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ global_step = pl_sd["global_step"]
+ else:
+ pl_sd = {"state_dict": None}
+ global_step = None
+ model = load_model_from_config(config.model,
+ pl_sd["state_dict"],
+ gpu=gpu,
+ eval_mode=eval_mode)["model"]
+ return dsets, model, global_step
+
+
+if __name__ == "__main__":
+ sys.path.append(os.getcwd())
+
+ parser = get_parser()
+
+ opt, unknown = parser.parse_known_args()
+
+ ckpt = None
+ if opt.resume:
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ try:
+ idx = len(paths)-paths[::-1].index("logs")+1
+ except ValueError:
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
+ logdir = "/".join(paths[:idx])
+ ckpt = opt.resume
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+ print(f"logdir:{logdir}")
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+ opt.base = base_configs+opt.base
+
+ if opt.config:
+ if type(opt.config) == str:
+ opt.base = [opt.config]
+ else:
+ opt.base = [opt.base[-1]]
+
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ if opt.ignore_base_data:
+ for config in configs:
+ if hasattr(config, "data"): del config["data"]
+ config = OmegaConf.merge(*configs, cli)
+
+ print(ckpt)
+ gpu = True
+ eval_mode = True
+ show_config = False
+ if show_config:
+ print(OmegaConf.to_container(config))
+
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
+ print(f"Global step: {global_step}")
+
+ outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
+ opt.top_k,
+ opt.temperature))
+ os.makedirs(outdir, exist_ok=True)
+ print("Writing samples to ", outdir)
+ for k in ["originals", "reconstructions", "samples"]:
+ os.makedirs(os.path.join(outdir, k), exist_ok=True)
+ run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
diff --git a/repositories/taming-transformers/scripts/make_scene_samples.py b/repositories/taming-transformers/scripts/make_scene_samples.py
new file mode 100644
index 0000000000000000000000000000000000000000..c096b98460874be0acbe5b85464593fbad4bedd0
--- /dev/null
+++ b/repositories/taming-transformers/scripts/make_scene_samples.py
@@ -0,0 +1,198 @@
+import glob
+import os
+import sys
+from itertools import product
+from pathlib import Path
+from typing import Literal, List, Optional, Tuple
+
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from pytorch_lightning import seed_everything
+from torch import Tensor
+from torchvision.utils import save_image
+from tqdm import tqdm
+
+from scripts.make_samples import get_parser, load_model_and_dset
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.models.cond_transformer import Net2NetTransformer
+
+seed_everything(42424242)
+device: Literal['cuda', 'cpu'] = 'cuda'
+first_stage_factor = 16
+trained_on_res = 256
+
+
+def _helper(coord: int, coord_max: int, coord_window: int) -> int:
+ assert 0 <= coord < coord_max
+ coord_desired_center = (coord_window - 1) // 2
+ return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
+
+
+def get_crop_coordinates(x: int, y: int) -> BoundingBox:
+ WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
+ x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
+ y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
+ w = first_stage_factor / WIDTH
+ h = first_stage_factor / HEIGHT
+ return x0, y0, w, h
+
+
+def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
+ WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
+ x0 = _helper(predict_x, WIDTH, first_stage_factor)
+ y0 = _helper(predict_y, HEIGHT, first_stage_factor)
+ no_images = z_indices.shape[0]
+ cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
+ cut_out_2 = z_indices[:, predict_y, x0:predict_x]
+ return torch.cat((cut_out_1, cut_out_2), dim=1)
+
+
+@torch.no_grad()
+def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
+ conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
+ temperature: float, top_k: int) -> Tensor:
+ x_max, y_max = desired_z_shape[1], desired_z_shape[0]
+
+ annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
+
+ recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
+ if not recompute_conditional:
+ crop_coordinates = get_crop_coordinates(0, 0)
+ conditional_indices = conditional_builder.build(annotations, crop_coordinates)
+ c_indices = conditional_indices.to(device).repeat(no_samples, 1)
+ z_indices = torch.zeros((no_samples, 0), device=device).long()
+ output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
+ sample=True, top_k=top_k)
+ else:
+ output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
+ for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
+ crop_coordinates = get_crop_coordinates(predict_x, predict_y)
+ z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
+ conditional_indices = conditional_builder.build(annotations, crop_coordinates)
+ c_indices = conditional_indices.to(device).repeat(no_samples, 1)
+ new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
+ output_indices[:, predict_y, predict_x] = new_index[:, -1]
+ z_shape = (
+ no_samples,
+ model.first_stage_model.quantize.e_dim, # codebook embed_dim
+ desired_z_shape[0], # z_height
+ desired_z_shape[1] # z_width
+ )
+ x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
+ x_sample = x_sample.to('cpu')
+
+ plotter = conditional_builder.plot
+ figure_size = (x_sample.shape[2], x_sample.shape[3])
+ scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
+ plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
+ return torch.cat((x_sample, plot.unsqueeze(0)))
+
+
+def get_resolution(resolution_str: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    if resolution_str.count(',') != 1:
+ raise ValueError("Give resolution as in 'height,width'")
+ res_h, res_w = resolution_str.split(',')
+ res_h = max(int(res_h), trained_on_res)
+ res_w = max(int(res_w), trained_on_res)
+ z_h = int(round(res_h/first_stage_factor))
+ z_w = int(round(res_w/first_stage_factor))
+ return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
+
+
+def add_arg_to_parser(parser):
+ parser.add_argument(
+ "-R",
+ "--resolution",
+ type=str,
+ default='256,256',
+ help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
+ )
+ parser.add_argument(
+ "-C",
+ "--conditional",
+ type=str,
+ default='objects_bbox',
+        help="objects_bbox or objects_center_points",
+ )
+ parser.add_argument(
+ "-N",
+ "--n_samples_per_layout",
+ type=int,
+ default=4,
+        help="how many samples to generate per layout",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ sys.path.append(os.getcwd())
+
+ parser = get_parser()
+ parser = add_arg_to_parser(parser)
+
+ opt, unknown = parser.parse_known_args()
+
+ ckpt = None
+ if opt.resume:
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ try:
+ idx = len(paths)-paths[::-1].index("logs")+1
+ except ValueError:
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
+ logdir = "/".join(paths[:idx])
+ ckpt = opt.resume
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+ print(f"logdir:{logdir}")
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+ opt.base = base_configs+opt.base
+
+ if opt.config:
+ if type(opt.config) == str:
+ opt.base = [opt.config]
+ else:
+ opt.base = [opt.base[-1]]
+
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ if opt.ignore_base_data:
+ for config in configs:
+ if hasattr(config, "data"):
+ del config["data"]
+ config = OmegaConf.merge(*configs, cli)
+ desired_z_shape, desired_resolution = get_resolution(opt.resolution)
+ conditional = opt.conditional
+
+ print(ckpt)
+ gpu = True
+ eval_mode = True
+ show_config = False
+ if show_config:
+ print(OmegaConf.to_container(config))
+
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
+ print(f"Global step: {global_step}")
+
+ data_loader = dsets.val_dataloader()
+ print(dsets.datasets["validation"].conditional_builders)
+ conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
+
+ outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
+ outdir.mkdir(exist_ok=True, parents=True)
+ print("Writing samples to ", outdir)
+
+ p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
+ for batch_no, batch in p_bar_1:
+ save_img: Optional[Tensor] = None
+ for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
+ imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
+ opt.n_samples_per_layout, opt.temperature, opt.top_k)
+            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout+1)
diff --git a/repositories/taming-transformers/scripts/reconstruction_usage.ipynb b/repositories/taming-transformers/scripts/reconstruction_usage.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..eb1d3aa48d5773ce5aece412ca396ed89efce6d9
--- /dev/null
+++ b/repositories/taming-transformers/scripts/reconstruction_usage.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f3d8f40cc9a0b1cbe4329e888ea25b5c018b290b4d92d3f1672a13f12d57589b
+size 13701699
diff --git a/repositories/taming-transformers/scripts/sample_conditional.py b/repositories/taming-transformers/scripts/sample_conditional.py
new file mode 100644
index 0000000000000000000000000000000000000000..174cf2af07c1a1ca4e6c35fc0e4f8d6e53591b56
--- /dev/null
+++ b/repositories/taming-transformers/scripts/sample_conditional.py
@@ -0,0 +1,355 @@
+import argparse, os, sys, glob, math, time
+import torch
+import numpy as np
+from omegaconf import OmegaConf
+import streamlit as st
+from streamlit import caching
+from PIL import Image
+from main import instantiate_from_config, DataModuleFromConfig
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+
+
+rescale = lambda x: (x + 1.) / 2.
+
+
+def bchw_to_st(x):
+ return rescale(x.detach().cpu().numpy().transpose(0,2,3,1))
+
+def save_img(xstart, fname):
+ I = (xstart.clip(0,1)[0]*255).astype(np.uint8)
+ Image.fromarray(I).save(fname)
+
+
+
+def get_interactive_image(resize=False):
+ image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
+ if image is not None:
+ image = Image.open(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ print("upload image shape: {}".format(image.shape))
+ img = Image.fromarray(image)
+ if resize:
+ img = img.resize((256, 256))
+ image = np.array(img)
+ return image
+
+
+def single_image_to_torch(x, permute=True):
+ assert x is not None, "Please provide an image through the upload function"
+ x = np.array(x)
+ x = torch.FloatTensor(x/255.*2. - 1.)[None,...]
+ if permute:
+ x = x.permute(0, 3, 1, 2)
+ return x
+
+
+def pad_to_M(x, M):
+ hp = math.ceil(x.shape[2]/M)*M-x.shape[2]
+ wp = math.ceil(x.shape[3]/M)*M-x.shape[3]
+ x = torch.nn.functional.pad(x, (0,wp,0,hp,0,0,0,0))
+ return x
+
+@torch.no_grad()
+def run_conditional(model, dsets):
+ if len(dsets.datasets) > 1:
+ split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
+ dset = dsets.datasets[split]
+ else:
+ dset = next(iter(dsets.datasets.values()))
+ batch_size = 1
+ start_index = st.sidebar.number_input("Example Index (Size: {})".format(len(dset)), value=0,
+ min_value=0,
+ max_value=len(dset)-batch_size)
+ indices = list(range(start_index, start_index+batch_size))
+
+ example = default_collate([dset[i] for i in indices])
+
+ x = model.get_input("image", example).to(model.device)
+
+ cond_key = model.cond_stage_key
+ c = model.get_input(cond_key, example).to(model.device)
+
+ scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5, max_value=4.0, step=0.25, value=1.00)
+ if scale_factor != 1.0:
+ x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
+ c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")
+
+ quant_z, z_indices = model.encode_to_z(x)
+ quant_c, c_indices = model.encode_to_c(c)
+
+ cshape = quant_z.shape
+
+ xrec = model.first_stage_model.decode(quant_z)
+ st.write("image: {}".format(x.shape))
+ st.image(bchw_to_st(x), clamp=True, output_format="PNG")
+ st.write("image reconstruction: {}".format(xrec.shape))
+ st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")
+
+ if cond_key == "segmentation":
+ # get image from segmentation mask
+ num_classes = c.shape[1]
+ c = torch.argmax(c, dim=1, keepdim=True)
+ c = torch.nn.functional.one_hot(c, num_classes=num_classes)
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ c = model.cond_stage_model.to_rgb(c)
+
+ st.write(f"{cond_key}: {tuple(c.shape)}")
+ st.image(bchw_to_st(c), clamp=True, output_format="PNG")
+
+ idx = z_indices
+
+ half_sample = st.sidebar.checkbox("Image Completion", value=False)
+ if half_sample:
+ start = idx.shape[1]//2
+ else:
+ start = 0
+
+ idx[:,start:] = 0
+ idx = idx.reshape(cshape[0],cshape[2],cshape[3])
+ start_i = start//cshape[3]
+ start_j = start %cshape[3]
+
+ if not half_sample and quant_z.shape == quant_c.shape:
+ st.info("Setting idx to c_indices")
+ idx = c_indices.clone().reshape(cshape[0],cshape[2],cshape[3])
+
+ cidx = c_indices
+ cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
+
+ xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
+ st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")
+
+ temperature = st.number_input("Temperature", value=1.0)
+ top_k = st.number_input("Top k", value=100)
+ sample = st.checkbox("Sample", value=True)
+ update_every = st.number_input("Update every", value=75)
+
+ st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")
+
+ animate = st.checkbox("animate")
+ if animate:
+ import imageio
+ outvid = "sampling.mp4"
+ writer = imageio.get_writer(outvid, fps=25)
+ elapsed_t = st.empty()
+ info = st.empty()
+ st.text("Sampled")
+ if st.button("Sample"):
+ output = st.empty()
+ start_t = time.time()
+ for i in range(start_i,cshape[2]-0):
+ if i <= 8:
+ local_i = i
+ elif cshape[2]-i < 8:
+ local_i = 16-(cshape[2]-i)
+ else:
+ local_i = 8
+ for j in range(start_j,cshape[3]-0):
+ if j <= 8:
+ local_j = j
+ elif cshape[3]-j < 8:
+ local_j = 16-(cshape[3]-j)
+ else:
+ local_j = 8
+
+ i_start = i-local_i
+ i_end = i_start+16
+ j_start = j-local_j
+ j_end = j_start+16
+ elapsed_t.text(f"Time: {time.time() - start_t} seconds")
+ info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")
+ patch = idx[:,i_start:i_end,j_start:j_end]
+ patch = patch.reshape(patch.shape[0],-1)
+ cpatch = cidx[:, i_start:i_end, j_start:j_end]
+ cpatch = cpatch.reshape(cpatch.shape[0], -1)
+ patch = torch.cat((cpatch, patch), dim=1)
+ logits,_ = model.transformer(patch[:,:-1])
+ logits = logits[:, -256:, :]
+ logits = logits.reshape(cshape[0],16,16,-1)
+ logits = logits[:,local_i,local_j,:]
+
+ logits = logits/temperature
+
+ if top_k is not None:
+ logits = model.top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ idx[:,i,j] = ix
+
+ if (i*cshape[3]+j)%update_every==0:
+ xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
+
+ xstart = bchw_to_st(xstart)
+ output.image(xstart, clamp=True, output_format="PNG")
+
+ if animate:
+ writer.append_data((xstart[0]*255).clip(0, 255).astype(np.uint8))
+
+ xstart = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
+ xstart = bchw_to_st(xstart)
+ output.image(xstart, clamp=True, output_format="PNG")
+ #save_img(xstart, "full_res_sample.png")
+ if animate:
+ writer.close()
+ st.video(outvid)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-r",
+ "--resume",
+ type=str,
+ nargs="?",
+ help="load from logdir or checkpoint in logdir",
+ )
+ parser.add_argument(
+ "-b",
+ "--base",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right. "
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+ default=list(),
+ )
+ parser.add_argument(
+ "-c",
+ "--config",
+ nargs="?",
+ metavar="single_config.yaml",
+ help="path to single config. If specified, base configs will be ignored "
+ "(except for the last one if left unspecified).",
+ const=True,
+ default="",
+ )
+ parser.add_argument(
+ "--ignore_base_data",
+ action="store_true",
+ help="Ignore data specification from base configs. Useful if you want "
+        "to specify custom datasets on the command line.",
+ )
+ return parser
+
+
+def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+ if "ckpt_path" in config.params:
+ st.warning("Deleting the restore-ckpt path from the config...")
+ config.params.ckpt_path = None
+ if "downsample_cond_size" in config.params:
+ st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
+ config.params.downsample_cond_size = -1
+ config.params["downsample_cond_factor"] = 0.5
+ try:
+ if "ckpt_path" in config.params.first_stage_config.params:
+ config.params.first_stage_config.params.ckpt_path = None
+ st.warning("Deleting the first-stage restore-ckpt path from the config...")
+ if "ckpt_path" in config.params.cond_stage_config.params:
+ config.params.cond_stage_config.params.ckpt_path = None
+ st.warning("Deleting the cond-stage restore-ckpt path from the config...")
+    except Exception:
+ pass
+
+ model = instantiate_from_config(config)
+ if sd is not None:
+ missing, unexpected = model.load_state_dict(sd, strict=False)
+ st.info(f"Missing Keys in State Dict: {missing}")
+ st.info(f"Unexpected Keys in State Dict: {unexpected}")
+ if gpu:
+ model.cuda()
+ if eval_mode:
+ model.eval()
+ return {"model": model}
+
+
+def get_data(config):
+ # get data
+ data = instantiate_from_config(config.data)
+ data.prepare_data()
+ data.setup()
+ return data
+
+
+@st.cache(allow_output_mutation=True, suppress_st_warning=True)
+def load_model_and_dset(config, ckpt, gpu, eval_mode):
+ # get data
+ dsets = get_data(config) # calls data.config ...
+
+ # now load the specified checkpoint
+ if ckpt:
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ global_step = pl_sd["global_step"]
+ else:
+ pl_sd = {"state_dict": None}
+ global_step = None
+ model = load_model_from_config(config.model,
+ pl_sd["state_dict"],
+ gpu=gpu,
+ eval_mode=eval_mode)["model"]
+ return dsets, model, global_step
+
+
+if __name__ == "__main__":
+ sys.path.append(os.getcwd())
+
+ parser = get_parser()
+
+ opt, unknown = parser.parse_known_args()
+
+ ckpt = None
+ if opt.resume:
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ try:
+ idx = len(paths)-paths[::-1].index("logs")+1
+ except ValueError:
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
+ logdir = "/".join(paths[:idx])
+ ckpt = opt.resume
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+ print(f"logdir:{logdir}")
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+ opt.base = base_configs+opt.base
+
+ if opt.config:
+ if type(opt.config) == str:
+ opt.base = [opt.config]
+ else:
+ opt.base = [opt.base[-1]]
+
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ if opt.ignore_base_data:
+ for config in configs:
+ if hasattr(config, "data"): del config["data"]
+ config = OmegaConf.merge(*configs, cli)
+
+ st.sidebar.text(ckpt)
+ gs = st.sidebar.empty()
+    gs.text("Global step: ?")
+ st.sidebar.text("Options")
+ #gpu = st.sidebar.checkbox("GPU", value=True)
+ gpu = True
+ #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
+ eval_mode = True
+ #show_config = st.sidebar.checkbox("Show Config", value=False)
+ show_config = False
+ if show_config:
+ st.info("Checkpoint: {}".format(ckpt))
+ st.json(OmegaConf.to_container(config))
+
+ dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
+ gs.text(f"Global step: {global_step}")
+ run_conditional(model, dsets)
diff --git a/repositories/taming-transformers/scripts/sample_fast.py b/repositories/taming-transformers/scripts/sample_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff546c7dcbe459807ac3b70f834ccc1082fe8b4e
--- /dev/null
+++ b/repositories/taming-transformers/scripts/sample_fast.py
@@ -0,0 +1,260 @@
+import argparse, os, sys, glob
+import torch
+import time
+import numpy as np
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from einops import repeat
+
+from main import instantiate_from_config
+from taming.modules.transformer.mingpt import sample_with_past
+
+
+rescale = lambda x: (x + 1.) / 2.
+
+
+def chw_to_pillow(x):
+ return Image.fromarray((255*rescale(x.detach().cpu().numpy().transpose(1,2,0))).clip(0,255).astype(np.uint8))
+
+
+@torch.no_grad()
+def sample_classconditional(model, batch_size, class_label, steps=256, temperature=None, top_k=None, callback=None,
+ dim_z=256, h=16, w=16, verbose_time=False, top_p=None):
+ log = dict()
+ assert type(class_label) == int, f'expecting type int but type is {type(class_label)}'
+ qzshape = [batch_size, dim_z, h, w]
+ assert not model.be_unconditional, 'Expecting a class-conditional Net2NetTransformer.'
+ c_indices = repeat(torch.tensor([class_label]), '1 -> b 1', b=batch_size).to(model.device) # class token
+ t1 = time.time()
+ index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
+ sample_logits=True, top_k=top_k, callback=callback,
+ temperature=temperature, top_p=top_p)
+ if verbose_time:
+ sampling_time = time.time() - t1
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+ x_sample = model.decode_to_img(index_sample, qzshape)
+ log["samples"] = x_sample
+ log["class_label"] = c_indices
+ return log
+
+
+@torch.no_grad()
+def sample_unconditional(model, batch_size, steps=256, temperature=None, top_k=None, top_p=None, callback=None,
+ dim_z=256, h=16, w=16, verbose_time=False):
+ log = dict()
+ qzshape = [batch_size, dim_z, h, w]
+ assert model.be_unconditional, 'Expecting an unconditional model.'
+ c_indices = repeat(torch.tensor([model.sos_token]), '1 -> b 1', b=batch_size).to(model.device) # sos token
+ t1 = time.time()
+ index_sample = sample_with_past(c_indices, model.transformer, steps=steps,
+ sample_logits=True, top_k=top_k, callback=callback,
+ temperature=temperature, top_p=top_p)
+ if verbose_time:
+ sampling_time = time.time() - t1
+ print(f"Full sampling takes about {sampling_time:.2f} seconds.")
+ x_sample = model.decode_to_img(index_sample, qzshape)
+ log["samples"] = x_sample
+ return log
+
+
+@torch.no_grad()
+def run(logdir, model, batch_size, temperature, top_k, unconditional=True, num_samples=50000,
+ given_classes=None, top_p=None):
+ batches = [batch_size for _ in range(num_samples//batch_size)] + [num_samples % batch_size]
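+    # e.g. num_samples=50000, batch_size=25 -> 2000 full batches plus a trailing
+    # 0-sized batch, which the `if bs == 0: break` checks below skip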
+ if not unconditional:
+ assert given_classes is not None
+ print("Running in pure class-conditional sampling mode. I will produce "
+ f"{num_samples} samples for each of the {len(given_classes)} classes, "
+ f"i.e. {num_samples*len(given_classes)} in total.")
+ for class_label in tqdm(given_classes, desc="Classes"):
+ for n, bs in tqdm(enumerate(batches), desc="Sampling Class"):
+ if bs == 0: break
+ logs = sample_classconditional(model, batch_size=bs, class_label=class_label,
+ temperature=temperature, top_k=top_k, top_p=top_p)
+ save_from_logs(logs, logdir, base_count=n * batch_size, cond_key=logs["class_label"])
+ else:
+ print(f"Running in unconditional sampling mode, producing {num_samples} samples.")
+ for n, bs in tqdm(enumerate(batches), desc="Sampling"):
+ if bs == 0: break
+ logs = sample_unconditional(model, batch_size=bs, temperature=temperature, top_k=top_k, top_p=top_p)
+ save_from_logs(logs, logdir, base_count=n * batch_size)
+
+
+def save_from_logs(logs, logdir, base_count, key="samples", cond_key=None):
+ xx = logs[key]
+ for i, x in enumerate(xx):
+ x = chw_to_pillow(x)
+ count = base_count + i
+ if cond_key is None:
+ x.save(os.path.join(logdir, f"{count:06}.png"))
+ else:
+ condlabel = cond_key[i]
+ if type(condlabel) == torch.Tensor: condlabel = condlabel.item()
+ os.makedirs(os.path.join(logdir, str(condlabel)), exist_ok=True)
+ x.save(os.path.join(logdir, str(condlabel), f"{count:06}.png"))
+
+
+def get_parser():
+ def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-r",
+ "--resume",
+ type=str,
+ nargs="?",
+ help="load from logdir or checkpoint in logdir",
+ )
+ parser.add_argument(
+ "-o",
+ "--outdir",
+ type=str,
+ nargs="?",
+        help="path where the samples will be logged.",
+ default=""
+ )
+ parser.add_argument(
+ "-b",
+ "--base",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right. "
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+ default=list(),
+ )
+ parser.add_argument(
+ "-n",
+ "--num_samples",
+ type=int,
+ nargs="?",
+ help="num_samples to draw",
+ default=50000
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ nargs="?",
+ help="the batch size",
+ default=25
+ )
+ parser.add_argument(
+ "-k",
+ "--top_k",
+ type=int,
+ nargs="?",
+ help="top-k value to sample with",
+ default=250,
+ )
+ parser.add_argument(
+ "-t",
+ "--temperature",
+ type=float,
+ nargs="?",
+ help="temperature value to sample with",
+ default=1.0
+ )
+ parser.add_argument(
+ "-p",
+ "--top_p",
+ type=float,
+ nargs="?",
+ help="top-p value to sample with",
+ default=1.0
+ )
+ parser.add_argument(
+ "--classes",
+ type=str,
+ nargs="?",
+        help="specify comma-separated classes to sample from. Uses all 1000 classes by default.",
+ default="imagenet"
+ )
+ return parser
+
+
+def load_model_from_config(config, sd, gpu=True, eval_mode=True):
+ model = instantiate_from_config(config)
+ if sd is not None:
+ model.load_state_dict(sd)
+ if gpu:
+ model.cuda()
+ if eval_mode:
+ model.eval()
+ return {"model": model}
+
+
+def load_model(config, ckpt, gpu, eval_mode):
+ # load the specified checkpoint
+ if ckpt:
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ global_step = pl_sd["global_step"]
+ print(f"loaded model from global step {global_step}.")
+ else:
+ pl_sd = {"state_dict": None}
+ global_step = None
+ model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
+ return model, global_step
+
+
+if __name__ == "__main__":
+ sys.path.append(os.getcwd())
+ parser = get_parser()
+
+ opt, unknown = parser.parse_known_args()
+ assert opt.resume
+
+ ckpt = None
+
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ try:
+ idx = len(paths)-paths[::-1].index("logs")+1
+ except ValueError:
+ idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
+ logdir = "/".join(paths[:idx])
+ ckpt = opt.resume
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
+ opt.base = base_configs+opt.base
+
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ config = OmegaConf.merge(*configs, cli)
+
+ model, global_step = load_model(config, ckpt, gpu=True, eval_mode=True)
+
+ if opt.outdir:
+ print(f"Switching logdir from '{logdir}' to '{opt.outdir}'")
+ logdir = opt.outdir
+
+ if opt.classes == "imagenet":
+ given_classes = [i for i in range(1000)]
+ else:
+ cls_str = opt.classes
+ assert not cls_str.endswith(","), 'class string should not end with a ","'
+ given_classes = [int(c) for c in cls_str.split(",")]
+
+ logdir = os.path.join(logdir, "samples", f"top_k_{opt.top_k}_temp_{opt.temperature:.2f}_top_p_{opt.top_p}",
+ f"{global_step}")
+
+ print(f"Logging to {logdir}")
+ os.makedirs(logdir, exist_ok=True)
+
+ run(logdir, model, opt.batch_size, opt.temperature, opt.top_k, unconditional=model.be_unconditional,
+ given_classes=given_classes, num_samples=opt.num_samples, top_p=opt.top_p)
+
+ print("done.")
diff --git a/repositories/taming-transformers/scripts/taming-transformers.ipynb b/repositories/taming-transformers/scripts/taming-transformers.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8e0b8cb4721a64ad54bb4b7ee14c871150c35767
--- /dev/null
+++ b/repositories/taming-transformers/scripts/taming-transformers.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bb427db9bb669c6842a6de32e139408bf85e73be878fb1b2b16068975675a4e
+size 3581166
diff --git a/repositories/taming-transformers/setup.py b/repositories/taming-transformers/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a220d12b21d96c5093a218c406cf47f1e7c8761a
--- /dev/null
+++ b/repositories/taming-transformers/setup.py
@@ -0,0 +1,13 @@
+from setuptools import setup, find_packages
+
+setup(
+ name='taming-transformers',
+ version='0.0.1',
+ description='Taming Transformers for High-Resolution Image Synthesis',
+ packages=find_packages(),
+ install_requires=[
+ 'torch',
+ 'numpy',
+ 'tqdm',
+ ],
+)
diff --git a/repositories/taming-transformers/taming/data/ade20k.py b/repositories/taming-transformers/taming/data/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..366dae97207dbb8356598d636e14ad084d45bc76
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/ade20k.py
@@ -0,0 +1,124 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/ade20k_examples.txt",
+ data_root="data/ade20k_images",
+ segmentation_root="data/ade20k_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=151, shift_segmentation=False)
+
+
+# With semantic map and scene label
+class ADE20kBase(Dataset):
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
+ self.split = self.get_split()
+ self.n_labels = 151 # unknown + 150
+ self.data_csv = {"train": "data/ade20k_train.txt",
+ "validation": "data/ade20k_test.txt"}[self.split]
+ self.data_root = "data/ade20k_root"
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
+ self.scene_categories = f.read().splitlines()
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, "images", l)
+ for l in self.image_paths],
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
+ l.replace(".jpg", ".png"))
+ for l in self.image_paths],
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
+ for l in self.image_paths],
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size if size is not None else None
+ else:
+ self.crop_size = crop_size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+
+ if crop_size is not None:
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image, mask=segmentation)
+ else:
+ processed = {"image": image, "mask": segmentation}
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class ADE20kTrain(ADE20kBase):
+ # default to random_crop=True
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ interpolation=interpolation, crop_size=crop_size)
+
+ def get_split(self):
+ return "train"
+
+
+class ADE20kValidation(ADE20kBase):
+ def get_split(self):
+ return "validation"
+
+
+if __name__ == "__main__":
+ dset = ADE20kValidation()
+ ex = dset[0]
+ for k in ["image", "scene_category", "segmentation"]:
+ print(type(ex[k]))
+ try:
+ print(ex[k].shape)
+ except:
+ print(ex[k])
diff --git a/repositories/taming-transformers/taming/data/annotated_objects_coco.py b/repositories/taming-transformers/taming/data/annotated_objects_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..af000ecd943d7b8a85d7eb70195c9ecd10ab5edc
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/annotated_objects_coco.py
@@ -0,0 +1,139 @@
+import json
+from itertools import chain
+from pathlib import Path
+from typing import Iterable, Dict, List, Callable, Any
+from collections import defaultdict
+
+from tqdm import tqdm
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, ImageDescription, Category
+
+COCO_PATH_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_train2017.json',
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
+ 'files': 'train2017'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_val2017.json',
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
+ 'files': 'val2017'
+ }
+}
+
+
+def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
+ return {
+ str(img['id']): ImageDescription(
+ id=img['id'],
+ license=img.get('license'),
+ file_name=img['file_name'],
+ coco_url=img['coco_url'],
+ original_size=(img['width'], img['height']),
+ date_captured=img.get('date_captured'),
+ flickr_url=img.get('flickr_url')
+ )
+ for img in description_json
+ }
+
+
+def load_categories(category_json: Iterable) -> Dict[str, Category]:
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
+ for cat in category_json if cat['name'] != 'other'}
+
+
+def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
+ annotations = defaultdict(list)
+ total = sum(len(a) for a in annotations_json)
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
+ image_id = str(ann['image_id'])
+ if image_id not in image_descriptions:
+ raise ValueError(f'image_id [{image_id}] has no image description.')
+ category_id = ann['category_id']
+ try:
+ category_no = category_no_for_id(str(category_id))
+ except KeyError:
+ continue
+
+ width, height = image_descriptions[image_id].original_size
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
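+        # COCO stores absolute pixel boxes (x, y, w, h); the division above rescales them to
+        # relative coordinates in [0, 1], e.g. [120, 60, 200, 150] in a 640x480 image becomes
+        # (0.1875, 0.125, 0.3125, 0.3125).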
+
+ annotations[image_id].append(
+ Annotation(
+ id=ann['id'],
+ area=bbox[2]*bbox[3], # use bbox area
+ is_group_of=ann['iscrowd'],
+ image_id=ann['image_id'],
+ bbox=bbox,
+ category_id=str(category_id),
+ category_no=category_no
+ )
+ )
+ return dict(annotations)
+
+
+class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ coco/
+ ├── annotations
+ │ ├── instances_train2017.json
+ │ ├── instances_val2017.json
+ │ ├── stuff_train2017.json
+ │ └── stuff_val2017.json
+ ├── train2017
+ │ ├── 000000000009.jpg
+ │ ├── 000000000025.jpg
+ │ └── ...
+ ├── val2017
+ │ ├── 000000000139.jpg
+ │ ├── 000000000285.jpg
+ │ └── ...
+        @param split: one of 'train' or 'validation'
+        @param target_image_size: desired image size (returns square images)
+ """
+ super().__init__(**kwargs)
+ self.use_things = use_things
+ self.use_stuff = use_stuff
+
+ with open(self.paths['instances_annotations']) as f:
+ inst_data_json = json.load(f)
+ with open(self.paths['stuff_annotations']) as f:
+ stuff_data_json = json.load(f)
+
+ category_jsons = []
+ annotation_jsons = []
+ if self.use_things:
+ category_jsons.append(inst_data_json['categories'])
+ annotation_jsons.append(inst_data_json['annotations'])
+ if self.use_stuff:
+ category_jsons.append(stuff_data_json['categories'])
+ annotation_jsons.append(stuff_data_json['annotations'])
+
+ self.categories = load_categories(chain(*category_jsons))
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
+ self.min_objects_per_image, self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in COCO_PATH_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
+ return COCO_PATH_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ # noinspection PyProtectedMember
+ return self.image_descriptions[image_id]._asdict()
diff --git a/repositories/taming-transformers/taming/data/annotated_objects_dataset.py b/repositories/taming-transformers/taming/data/annotated_objects_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cc346a1c76289a4964d7dc8a29582172f33dc0
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/annotated_objects_dataset.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+from typing import Optional, List, Callable, Dict, Any, Union
+import warnings
+
+import PIL.Image as pil_image
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import load_object_from_string
+from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
+from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
+
+
+class AnnotatedObjectsDataset(Dataset):
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
+ no_object_classes: Optional[int] = None):
+ self.data_path = data_path
+ self.split = split
+ self.keys = keys
+ self.target_image_size = target_image_size
+ self.min_object_area = min_object_area
+ self.min_objects_per_image = min_objects_per_image
+ self.max_objects_per_image = max_objects_per_image
+ self.crop_method = crop_method
+ self.random_flip = random_flip
+ self.no_tokens = no_tokens
+ self.use_group_parameter = use_group_parameter
+ self.encode_crop = encode_crop
+
+ self.annotations = None
+ self.image_descriptions = None
+ self.categories = None
+ self.category_ids = None
+ self.category_number = None
+ self.image_ids = None
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
+ self.paths = self.build_paths(self.data_path)
+ self._conditional_builders = None
+ self.category_allow_list = None
+ if category_allow_list_target:
+ allow_list = load_object_from_string(category_allow_list_target)
+ self.category_allow_list = {name for name, _ in allow_list}
+ self.category_mapping = {}
+ if category_mapping_target:
+ self.category_mapping = load_object_from_string(category_mapping_target)
+ self.no_object_classes = no_object_classes
+
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
+ top_level = Path(top_level)
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
+ for path in sub_paths.values():
+ if not path.exists():
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
+ return sub_paths
+
+ @staticmethod
+ def load_image_from_disk(path: Path) -> Image:
+ return pil_image.open(path).convert('RGB')
+
+ @staticmethod
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
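+        # Build the per-sample transform pipeline for the chosen crop method. The crop and flip
+        # transforms used here also return their parameters (crop bbox, flip flag), so that
+        # __getitem__() can pass them on to the conditional builders, which rescale the
+        # annotations accordingly.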
+ transform_functions = []
+ if crop_method == 'none':
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
+ elif crop_method == 'center':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ CenterCropReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-1d':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ RandomCrop1dReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-2d':
+ transform_functions.extend([
+ Random2dCropReturnCoordinates(target_image_size),
+ transforms.Resize(target_image_size)
+ ])
+ elif crop_method is None:
+ return None
+ else:
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
+ if random_flip:
+ transform_functions.append(RandomHorizontalFlipReturn())
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
+ return transform_functions
+
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
+ crop_bbox = None
+ flipped = None
+ for t in self.transform_functions:
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
+ crop_bbox, x = t(x)
+ elif isinstance(t, RandomHorizontalFlipReturn):
+ flipped, x = t(x)
+ else:
+ x = t(x)
+ return crop_bbox, flipped, x
+
+ @property
+ def no_classes(self) -> int:
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
+
+ @property
+    def conditional_builders(self) -> Dict[str, ObjectsCenterPointsConditionalBuilder]:
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
+ if self._conditional_builders is None:
+ self._conditional_builders = {
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ ),
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ )
+ }
+ return self._conditional_builders
+
+ def filter_categories(self) -> None:
+ if self.category_allow_list:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
+ if self.category_mapping:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
+
+ def setup_category_id_and_number(self) -> None:
+ self.category_ids = list(self.categories.keys())
+ self.category_ids.sort()
+ if '/m/01s55n' in self.category_ids:
+ self.category_ids.remove('/m/01s55n')
+ self.category_ids.append('/m/01s55n')
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
+        if self.category_allow_list is not None and not self.category_mapping \
+ and len(self.category_ids) != len(self.category_allow_list):
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
+ 'Make sure all names in category_allow_list exist.')
+
+ def clean_up_annotations_and_image_descriptions(self) -> None:
+ image_id_set = set(self.image_ids)
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
+
+ @staticmethod
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
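+        # Drop annotations at or below min_object_area, then keep only images whose remaining
+        # annotation count lies in [min_objects_per_image, max_objects_per_image].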
+ filtered = {}
+ for image_id, annotations in all_annotations.items():
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
+ filtered[image_id] = annotations_with_min_area
+ return filtered
+
+ def __len__(self):
+ return len(self.image_ids)
+
+ def __getitem__(self, n: int) -> Dict[str, Any]:
+ image_id = self.get_image_id(n)
+ sample = self.get_image_description(image_id)
+ sample['annotations'] = self.get_annotation(image_id)
+
+ if 'image' in self.keys:
+ sample['image_path'] = str(self.get_image_path(image_id))
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
+ sample['image'] = convert_pil_to_tensor(sample['image'])
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
+ sample['image'] = sample['image'].permute(1, 2, 0)
+
+ for conditional, builder in self.conditional_builders.items():
+ if conditional in self.keys:
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
+
+ if self.keys:
+ # only return specified keys
+ sample = {key: sample[key] for key in self.keys}
+ return sample
+
+ def get_image_id(self, no: int) -> str:
+ return self.image_ids[no]
+
+    def get_annotation(self, image_id: str) -> List[Annotation]:
+ return self.annotations[image_id]
+
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
+ return self.categories[category_id].name
+
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
+ return self.categories[self.get_category_id(category_no)].name
+
+ def get_category_number(self, category_id: str) -> int:
+ return self.category_number[category_id]
+
+ def get_category_id(self, category_no: int) -> str:
+ return self.category_ids[category_no]
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ raise NotImplementedError()
+
+ def get_path_structure(self):
+ raise NotImplementedError
+
+ def get_image_path(self, image_id: str) -> Path:
+ raise NotImplementedError
diff --git a/repositories/taming-transformers/taming/data/annotated_objects_open_images.py b/repositories/taming-transformers/taming/data/annotated_objects_open_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..aede6803d2cef7a74ca784e7907d35fba6c71239
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/annotated_objects_open_images.py
@@ -0,0 +1,137 @@
+from collections import defaultdict
+from csv import DictReader, reader as TupleReader
+from pathlib import Path
+from typing import Dict, List, Any
+import warnings
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, Category
+from tqdm import tqdm
+
+OPEN_IMAGES_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
+ 'file_list': 'train-images-boxable.csv',
+ 'files': 'train'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'validation-annotations-bbox.csv',
+ 'file_list': 'validation-images.csv',
+ 'files': 'validation'
+ },
+ 'test': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'test-annotations-bbox.csv',
+ 'file_list': 'test-images.csv',
+ 'files': 'test'
+ }
+}
+
+
+def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
+ with open(descriptor_path) as file:
+ reader = DictReader(file)
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
+ width = float(row['XMax']) - float(row['XMin'])
+ height = float(row['YMax']) - float(row['YMin'])
+ area = width * height
+ category_id = row['LabelName']
+ if category_id in category_mapping:
+ category_id = category_mapping[category_id]
+ if area >= min_object_area and category_id in category_no_for_id:
+ annotations[row['ImageID']].append(
+ Annotation(
+ id=i,
+ image_id=row['ImageID'],
+ source=row['Source'],
+ category_id=category_id,
+ category_no=category_no_for_id[category_id],
+ confidence=float(row['Confidence']),
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
+ area=area,
+ is_occluded=bool(int(row['IsOccluded'])),
+ is_truncated=bool(int(row['IsTruncated'])),
+ is_group_of=bool(int(row['IsGroupOf'])),
+ is_depiction=bool(int(row['IsDepiction'])),
+ is_inside=bool(int(row['IsInside']))
+ )
+ )
+ if 'train' in str(descriptor_path) and i < 14000000:
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
+ return dict(annotations)
+
+
+def load_image_ids(csv_path: Path) -> List[str]:
+ with open(csv_path) as file:
+ reader = DictReader(file)
+ return [row['image_name'] for row in reader]
+
+
+def load_categories(csv_path: Path) -> Dict[str, Category]:
+ with open(csv_path) as file:
+ reader = TupleReader(file)
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
+
+
+class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
+ def __init__(self, use_additional_parameters: bool, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ open_images/
+ ├── class-descriptions-boxable.csv
+ ├── oidv6-train-annotations-bbox.csv
+ ├── test
+ │ ├── 000026e7ee790996.jpg
+ │ ├── 000062a39995e348.jpg
+ │ └── ...
+ ├── test-annotations-bbox.csv
+ ├── test-images.csv
+ ├── train
+ │ ├── 000002b66c9c498e.jpg
+ │ ├── 000002b97e5471a0.jpg
+ │ └── ...
+ ├── train-images-boxable.csv
+ ├── validation
+ │ ├── 0001eeaf4aed83f9.jpg
+ │ ├── 0004886b7d043cfd.jpg
+ │ └── ...
+ ├── validation-annotations-bbox.csv
+ └── validation-images.csv
+        @param split: one of 'train', 'validation' or 'test'
+        @param target_image_size: desired image size (returns square images)
+ """
+
+ super().__init__(**kwargs)
+ self.use_additional_parameters = use_additional_parameters
+
+ self.categories = load_categories(self.paths['class_descriptions'])
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = {}
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
+ self.category_number)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
+ self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in OPEN_IMAGES_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
+ return OPEN_IMAGES_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ image_path = self.get_image_path(image_id)
+ return {'file_path': str(image_path), 'file_name': image_path.name}
diff --git a/repositories/taming-transformers/taming/data/base.py b/repositories/taming-transformers/taming/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/base.py
@@ -0,0 +1,70 @@
+import bisect
+import numpy as np
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset, ConcatDataset
+
+
+class ConcatDatasetWithIndex(ConcatDataset):
+ """Modified from original pytorch code to return dataset idx"""
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
+
+
+class ImagePaths(Dataset):
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
+ self.size = size
+ self.random_crop = random_crop
+
+ self.labels = dict() if labels is None else labels
+ self.labels["file_path_"] = paths
+ self._length = len(paths)
+
+ if self.size is not None and self.size > 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ if not self.random_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return self._length
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def __getitem__(self, i):
+ example = dict()
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
+ for k in self.labels:
+ example[k] = self.labels[k][i]
+ return example
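+
+# Illustrative usage (the path below is a placeholder):
+#   ds = ImagePaths(paths=["/data/img_0001.png"], size=256, random_crop=False)
+#   ex = ds[0]
+#   ex["image"]          # float32 array of shape (256, 256, 3), values in [-1, 1]
+#   ex["file_path_"]     # "/data/img_0001.png"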
+
+
+class NumpyPaths(ImagePaths):
+ def preprocess_image(self, image_path):
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
+ image = np.transpose(image, (1,2,0))
+ image = Image.fromarray(image, mode="RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
diff --git a/repositories/taming-transformers/taming/data/coco.py b/repositories/taming-transformers/taming/data/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b2f7838448cb63dcf96daffe9470d58566d975a
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/coco.py
@@ -0,0 +1,176 @@
+import os
+import json
+import albumentations
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/coco_examples.txt",
+ data_root="data/coco_images",
+ segmentation_root="data/coco_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=183, shift_segmentation=True)
+
+
+class CocoBase(Dataset):
+ """needed for (image, caption, segmentation) pairs"""
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
+ crop_size=None, force_no_crop=False, given_files=None):
+ self.split = self.get_split()
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size
+ else:
+ self.crop_size = crop_size
+
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
+ self.stuffthing = use_stuffthing # include thing in segmentation
+ if self.onehot and not self.stuffthing:
+            raise NotImplementedError("One hot mode is only supported for the "
+                                      "stuffthings version because labels are stored "
+                                      "differently.")
+
+ data_json = datajson
+ with open(data_json) as json_file:
+ self.json_data = json.load(json_file)
+ self.img_id_to_captions = dict()
+ self.img_id_to_filepath = dict()
+ self.img_id_to_segmentation_filepath = dict()
+
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
+ "captions_val2017.json"]
+ if self.stuffthing:
+ self.segmentation_prefix = (
+ "data/cocostuffthings/val2017" if
+ data_json.endswith("captions_val2017.json") else
+ "data/cocostuffthings/train2017")
+ else:
+ self.segmentation_prefix = (
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
+ data_json.endswith("captions_val2017.json") else
+ "data/coco/annotations/stuff_train2017_pixelmaps")
+
+ imagedirs = self.json_data["images"]
+ self.labels = {"image_ids": list()}
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
+ self.img_id_to_captions[imgdir["id"]] = list()
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
+ self.segmentation_prefix, pngfilename)
+ if given_files is not None:
+ if pngfilename in given_files:
+ self.labels["image_ids"].append(imgdir["id"])
+ else:
+ self.labels["image_ids"].append(imgdir["id"])
+
+ capdirs = self.json_data["annotations"]
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
+            # there are on average 5 captions per image
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
+
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+ if self.split=="validation":
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler, self.cropper],
+ additional_targets={"segmentation": "image"})
+ if force_no_crop:
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler],
+ additional_targets={"segmentation": "image"})
+
+ def __len__(self):
+ return len(self.labels["image_ids"])
+
+ def preprocess_image(self, image_path, segmentation_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+
+ segmentation = Image.open(segmentation_path)
+ if not self.onehot and not segmentation.mode == "RGB":
+ segmentation = segmentation.convert("RGB")
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.onehot:
+ assert self.stuffthing
+ # stored in caffe format: unlabeled==255. stuff and thing from
+ # 0-181. to be compatible with the labels in
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
+ # we shift stuffthing one to the right and put unlabeled in zero
+ # as long as segmentation is uint8 shifting to right handles the
+ # latter too
+ assert segmentation.dtype == np.uint8
+ segmentation = segmentation + 1
+
+ processed = self.preprocessor(image=image, segmentation=segmentation)
+ image, segmentation = processed["image"], processed["segmentation"]
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ if self.onehot:
+ assert segmentation.dtype == np.uint8
+ # make it one hot
+ n_labels = 183
+ flatseg = np.ravel(segmentation)
+            onehot = np.zeros((flatseg.size, n_labels), dtype=bool)  # np.bool was removed in NumPy 1.24
+ onehot[np.arange(flatseg.size), flatseg] = True
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
+ segmentation = onehot
+ else:
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
+ return image, segmentation
+
+ def __getitem__(self, i):
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
+ image, segmentation = self.preprocess_image(img_path, seg_path)
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
+ # randomly draw one of all available captions per image
+ caption = captions[np.random.randint(0, len(captions))]
+ example = {"image": image,
+ "caption": [str(caption[0])],
+ "segmentation": segmentation,
+ "img_path": img_path,
+ "seg_path": seg_path,
+ "filename_": img_path.split(os.sep)[-1]
+ }
+ return example
+
+
+class CocoImagesAndCaptionsTrain(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
+ super().__init__(size=size,
+ dataroot="data/coco/train2017",
+ datajson="data/coco/annotations/captions_train2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
+
+ def get_split(self):
+ return "train"
+
+
+class CocoImagesAndCaptionsValidation(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+ given_files=None):
+ super().__init__(size=size,
+ dataroot="data/coco/val2017",
+ datajson="data/coco/annotations/captions_val2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+ given_files=given_files)
+
+ def get_split(self):
+ return "validation"
diff --git a/repositories/taming-transformers/taming/data/conditional_builder/objects_bbox.py b/repositories/taming-transformers/taming/data/conditional_builder/objects_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..15881e76b7ab2a914df8f2dfe08ae4f0c6c511b5
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/conditional_builder/objects_bbox.py
@@ -0,0 +1,60 @@
+from itertools import cycle
+from typing import List, Tuple, Callable, Optional
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
+ pad_list, get_plot_font_size, absolute_bbox
+
+
+class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
+ @property
+ def object_descriptor_length(self) -> int:
+ return 3
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_triples = [
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
+ for ann in annotations
+ ]
+ empty_triple = (self.none, self.none, self.none)
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
+ return object_triples
+
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ object_triples = grouper(conditional_list, 3)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
+ for object_triple in object_triples if object_triple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ font = ImageFont.truetype(
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
+ size=get_plot_font_size(font_size, figure_size)
+ )
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
+ annotation = self.representation_to_annotation(representation)
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
+ bbox = absolute_bbox(bbox, width, height)
+ draw.rectangle(bbox, outline=color, width=line_width)
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
diff --git a/repositories/taming-transformers/taming/data/conditional_builder/objects_center_points.py b/repositories/taming-transformers/taming/data/conditional_builder/objects_center_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a480329cc47fb38a7b8729d424e092b77d40749
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/conditional_builder/objects_center_points.py
@@ -0,0 +1,168 @@
+import math
+import random
+import warnings
+from itertools import cycle
+from typing import List, Optional, Tuple, Callable
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
+ absolute_bbox, rescale_annotations
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+
+class ObjectsCenterPointsConditionalBuilder:
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
+ use_group_parameter: bool, use_additional_parameters: bool):
+ self.no_object_classes = no_object_classes
+ self.no_max_objects = no_max_objects
+ self.no_tokens = no_tokens
+ self.encode_crop = encode_crop
+ self.no_sections = int(math.sqrt(self.no_tokens))
+ self.use_group_parameter = use_group_parameter
+ self.use_additional_parameters = use_additional_parameters
+
+ @property
+ def none(self) -> int:
+ return self.no_tokens - 1
+
+ @property
+ def object_descriptor_length(self) -> int:
+ return 2
+
+ @property
+ def embedding_dim(self) -> int:
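+        # e.g. with no_max_objects=30, a 2-token descriptor per object and encode_crop=True this
+        # yields 30 * 2 + 2 = 62 conditioning tokens.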
+ extra_length = 2 if self.encode_crop else 0
+ return self.no_max_objects * self.object_descriptor_length + extra_length
+
+ def tokenize_coordinates(self, x: float, y: float) -> int:
+ """
+ Express 2d coordinates with one number.
+ Example: assume self.no_tokens = 16, then no_sections = 4:
+ 0 0 0 0
+ 0 0 # 0
+ 0 0 0 0
+ 0 0 0 x
+ Then the # position corresponds to token 6, the x position to token 15.
+ @param x: float in [0, 1]
+ @param y: float in [0, 1]
+ @return: discrete tokenized coordinate
+ """
+ x_discrete = int(round(x * (self.no_sections - 1)))
+ y_discrete = int(round(y * (self.no_sections - 1)))
+ return y_discrete * self.no_sections + x_discrete
+
+ def coordinates_from_token(self, token: int) -> (float, float):
+ x = token % self.no_sections
+ y = token // self.no_sections
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
+
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
+ x0, y0 = self.coordinates_from_token(token1)
+ x1, y1 = self.coordinates_from_token(token2)
+ return x0, y0, x1 - x0, y1 - y0
+
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+ def inverse_build(self, conditional: LongTensor) \
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
+ for object_tuple in table_of_content if object_tuple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ circle_size = get_circle_size(figure_size)
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
+ size=get_plot_font_size(font_size, figure_size))
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
+ x_abs, y_abs = x * width, y * height
+ ann = self.representation_to_annotation(representation)
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
+
+ def object_representation(self, annotation: Annotation) -> int:
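+        # Pack the boolean object attributes into a bit-flag modifier so that a single token
+        # encodes both class and flags, e.g. category_no=7, no_object_classes=100,
+        # is_group_of=True gives 7 + 100 * 1 = 107; representation_to_annotation() inverts this.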
+ modifier = 0
+ if self.use_group_parameter:
+ modifier |= 1 * (annotation.is_group_of is True)
+ if self.use_additional_parameters:
+ modifier |= 2 * (annotation.is_occluded is True)
+ modifier |= 4 * (annotation.is_depiction is True)
+ modifier |= 8 * (annotation.is_inside is True)
+ return annotation.category_no + self.no_object_classes * modifier
+
+ def representation_to_annotation(self, representation: int) -> Annotation:
+ category_no = representation % self.no_object_classes
+ modifier = representation // self.no_object_classes
+ # noinspection PyTypeChecker
+ return Annotation(
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
+ category_no=category_no,
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
+ )
+
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
+ return list(self.token_pair_from_bbox(crop_coordinates))
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_tuples = [
+ (self.object_representation(a),
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
+ for a in annotations
+ ]
+ empty_tuple = (self.none, self.none)
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
+ return object_tuples
+
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
+ -> LongTensor:
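+        # Returns a fixed-length LongTensor of exactly self.embedding_dim tokens: one descriptor
+        # tuple per (shuffled, crop-filtered, rescaled) object, padded with the `none` token,
+        # plus two crop-corner tokens when encode_crop is enabled.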
+ if len(annotations) == 0:
+ warnings.warn('Did not receive any annotations.')
+ if len(annotations) > self.no_max_objects:
+ warnings.warn('Received more annotations than allowed.')
+ annotations = annotations[:self.no_max_objects]
+
+ if not crop_coordinates:
+ crop_coordinates = FULL_CROP
+
+ random.shuffle(annotations)
+ annotations = filter_annotations(annotations, crop_coordinates)
+ if self.encode_crop:
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
+ if horizontal_flip:
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
+ extra = self._crop_encoder(crop_coordinates)
+ else:
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
+ extra = []
+
+ object_tuples = self._make_object_descriptors(annotations)
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
+ assert len(flattened) == self.embedding_dim
+ assert all(0 <= value < self.no_tokens for value in flattened)
+ return LongTensor(flattened)
diff --git a/repositories/taming-transformers/taming/data/conditional_builder/utils.py b/repositories/taming-transformers/taming/data/conditional_builder/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0ee175f2e05a80dbc71c22acbecb22dddadbb42
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/conditional_builder/utils.py
@@ -0,0 +1,105 @@
+import importlib
+from typing import List, Any, Tuple, Optional
+
+from taming.data.helper_types import BoundingBox, Annotation
+
+# source: seaborn, color palette tab10
+COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
+BLACK = (0, 0, 0)
+GRAY_75 = (63, 63, 63)
+GRAY_50 = (127, 127, 127)
+GRAY_25 = (191, 191, 191)
+WHITE = (255, 255, 255)
+FULL_CROP = (0., 0., 1., 1.)
+
+
+def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
+ """
+ Give intersection area of two rectangles.
+ @param rectangle1: (x0, y0, w, h) of first rectangle
+ @param rectangle2: (x0, y0, w, h) of second rectangle
+ """
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
+ return x_overlap * y_overlap
+
+
+def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
+
+
+def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
+ bbox = relative_bbox
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+
+
+def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
+
+
+def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
+ List[Annotation]:
+ def clamp(x: float):
+ return max(min(x, 1.), 0.)
+
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ if flip:
+ x0 = 1 - (x0 + w)
+ return x0, y0, w, h
+
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
+
+
+def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
+
+
+def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
+ sl = slice(1) if short else slice(None)
+ string = ''
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
+ return string
+ if annotation.is_group_of:
+ string += 'group'[sl] + ','
+ if annotation.is_occluded:
+ string += 'occluded'[sl] + ','
+ if annotation.is_depiction:
+ string += 'depiction'[sl] + ','
+ if annotation.is_inside:
+ string += 'inside'[sl]
+ return '(' + string.strip(",") + ')'
+
+
+def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
+ if font_size is None:
+ font_size = 10
+ if max(figure_size) >= 256:
+ font_size = 12
+ if max(figure_size) >= 512:
+ font_size = 15
+ return font_size
+
+
+def get_circle_size(figure_size: Tuple[int, int]) -> int:
+ circle_size = 2
+ if max(figure_size) >= 256:
+ circle_size = 3
+ if max(figure_size) >= 512:
+ circle_size = 4
+ return circle_size
+
+
+def load_object_from_string(object_string: str) -> Any:
+ """
+ Source: https://stackoverflow.com/a/10773699
+ """
+ module_name, class_name = object_string.rsplit(".", 1)
+ return getattr(importlib.import_module(module_name), class_name)
diff --git a/repositories/taming-transformers/taming/data/custom.py b/repositories/taming-transformers/taming/data/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f302a4b55ba1e8ec282ec3292b6263c06dfb91
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/custom.py
@@ -0,0 +1,38 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class CustomBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ return example
+
+
+
+class CustomTrain(CustomBase):
+ def __init__(self, size, training_images_list_file):
+ super().__init__()
+ with open(training_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
+class CustomTest(CustomBase):
+ def __init__(self, size, test_images_list_file):
+ super().__init__()
+ with open(test_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
diff --git a/repositories/taming-transformers/taming/data/faceshq.py b/repositories/taming-transformers/taming/data/faceshq.py
new file mode 100644
index 0000000000000000000000000000000000000000..6912d04b66a6d464c1078e4b51d5da290f5e767e
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/faceshq.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class FacesBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+ self.keys = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ ex = {}
+ if self.keys is not None:
+ for k in self.keys:
+ ex[k] = example[k]
+ else:
+ ex = example
+ return ex
+
+
+class CelebAHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class CelebAHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FacesHQTrain(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQTrain(size=size, keys=keys)
+ d2 = FFHQTrain(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
+
+
+class FacesHQValidation(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQValidation(size=size, keys=keys)
+ d2 = FFHQValidation(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
diff --git a/repositories/taming-transformers/taming/data/helper_types.py b/repositories/taming-transformers/taming/data/helper_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/helper_types.py
@@ -0,0 +1,49 @@
+from typing import Dict, Tuple, Optional, NamedTuple, Union
+from PIL.Image import Image as pil_image
+from torch import Tensor
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+Image = Union[Tensor, pil_image]
+BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
+CropMethodType = Literal['none', 'center', 'random-1d', 'random-2d']
+SplitType = Literal['train', 'validation', 'test']
+
+
+class ImageDescription(NamedTuple):
+ id: int
+ file_name: str
+ original_size: Tuple[int, int] # w, h
+ url: Optional[str] = None
+ license: Optional[int] = None
+ coco_url: Optional[str] = None
+ date_captured: Optional[str] = None
+ flickr_url: Optional[str] = None
+ flickr_id: Optional[str] = None
+ coco_id: Optional[str] = None
+
+
+class Category(NamedTuple):
+ id: str
+ super_category: Optional[str]
+ name: str
+
+
+class Annotation(NamedTuple):
+ area: float
+ image_id: str
+ bbox: BoundingBox
+ category_no: int
+ category_id: str
+ id: Optional[int] = None
+ source: Optional[str] = None
+ confidence: Optional[float] = None
+ is_group_of: Optional[bool] = None
+ is_truncated: Optional[bool] = None
+ is_occluded: Optional[bool] = None
+ is_depiction: Optional[bool] = None
+ is_inside: Optional[bool] = None
+ segmentation: Optional[Dict] = None
diff --git a/repositories/taming-transformers/taming/data/image_transforms.py b/repositories/taming-transformers/taming/data/image_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..657ac332174e0ac72f68315271ffbd757b771a0f
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/image_transforms.py
@@ -0,0 +1,132 @@
+import random
+import warnings
+from typing import Union
+
+import torch
+from torch import Tensor
+from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
+try:
+    from torchvision.transforms.functional import get_image_size
+except ImportError:  # older torchvision only exposes the private helper
+    from torchvision.transforms.functional import _get_image_size as get_image_size
+
+from taming.data.helper_types import BoundingBox, Image
+
+pil_to_tensor = PILToTensor()
+
+
+def convert_pil_to_tensor(image: Image) -> Tensor:
+ with warnings.catch_warnings():
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
+ warnings.simplefilter("ignore")
+ return pil_to_tensor(image)
+
+
+class RandomCrop1dReturnCoordinates(RandomCrop):
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ width, height = get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+ bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
+ return bbox, F.crop(img, i, j, h, w)
+
+
+class Random2dCropReturnCoordinates(torch.nn.Module):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+
+ def __init__(self, min_size: int):
+ super().__init__()
+ self.min_size = min_size
+
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ width, height = get_image_size(img)
+ max_size = min(width, height)
+ if max_size <= self.min_size:
+ size = max_size
+ else:
+ size = random.randint(self.min_size, max_size)
+ top = random.randint(0, height - size)
+ left = random.randint(0, width - size)
+ bbox = left / width, top / height, size / width, size / height
+ return bbox, F.crop(img, top, left, size, size)
+
+
+class CenterCropReturnCoordinates(CenterCrop):
+ @staticmethod
+ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
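+        # e.g. a 640x480 landscape image yields (0.125, 0.0, 0.75, 1.0): the square center crop
+        # spans the full height and the middle 75% of the width.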
+ if width > height:
+ w = height / width
+ h = 1.0
+ x0 = 0.5 - w / 2
+ y0 = 0.
+ else:
+ w = 1.0
+ h = width / height
+ x0 = 0.
+ y0 = 0.5 - h / 2
+ return x0, y0, w, h
+
+ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ width, height = get_image_size(img)
+ return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
+
+
+class RandomHorizontalFlipReturn(RandomHorizontalFlip):
+ def forward(self, img: Image) -> (bool, Image):
+ """
+ Additionally to flipping, returns a boolean whether it was flipped or not.
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ flipped: whether the image was flipped or not
+ PIL Image or Tensor: Randomly flipped image.
+
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ if torch.rand(1) < self.p:
+ return True, F.hflip(img)
+ return False, img
diff --git a/repositories/taming-transformers/taming/data/imagenet.py b/repositories/taming-transformers/taming/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a02ec44ba4af9e993f58c91fa43482a4ecbe54c
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/imagenet.py
@@ -0,0 +1,558 @@
+import os, tarfile, glob, shutil
+import yaml
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import albumentations
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths
+from taming.util import download, retrieve
+import taming.data.utils as bdu
+
+
+def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
+ synsets = []
+ with open(path_to_yaml) as f:
+        di2s = yaml.safe_load(f)
+ for idx in indices:
+ synsets.append(str(di2s[idx]))
+ print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
+ return synsets
+
+
+def str_to_indices(string):
+ """Expects a string in the format '32-123, 256, 280-321'"""
+ assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
+ subs = string.split(",")
+ indices = []
+ for sub in subs:
+ subsubs = sub.split("-")
+ assert len(subsubs) > 0
+ if len(subsubs) == 1:
+ indices.append(int(subsubs[0]))
+ else:
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
+ indices.extend(rang)
+ return sorted(indices)
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+        if not isinstance(self.config, dict):
+ self.config = OmegaConf.to_container(self.config)
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ self.class_labels = [class_dict[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=retrieve(self.config, "size", default=0),
+ random_crop=self.random_crop)
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+def get_preprocessor(size=None, random_crop=False, additional_targets=None,
+ crop_size=None):
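+    # three cases: size > 0 -> rescale shortest side to `size`, then crop to size x size and random h-flip;
+    # only crop_size > 0 -> crop only; otherwise -> identity (kwargs are passed through unchanged)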
+ if size is not None and size > 0:
+ transforms = list()
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
+ transforms.append(rescaler)
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=size,width=size)
+ transforms.append(cropper)
+ else:
+ cropper = albumentations.RandomCrop(height=size,width=size)
+ transforms.append(cropper)
+ flipper = albumentations.HorizontalFlip()
+ transforms.append(flipper)
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ elif crop_size is not None and crop_size > 0:
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ transforms = [cropper]
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ preprocessor = lambda **kwargs: kwargs
+ return preprocessor
+
+
+def rgba_to_depth(x):
+ assert x.dtype == np.uint8
+ assert len(x.shape) == 3 and x.shape[2] == 4
+ y = x.copy()
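+    # reinterpret the four uint8 channels of each pixel as a single float32 depth value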
+ y.dtype = np.float32
+ y = y.reshape(x.shape[:2])
+ return np.ascontiguousarray(y)
+
+
+class BaseWithDepth(Dataset):
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
+
+ def __init__(self, config=None, size=None, random_crop=False,
+ crop_size=None, root=None):
+ self.config = config
+ self.base_dset = self.get_base_dset()
+ self.preprocessor = get_preprocessor(
+ size=size,
+ crop_size=crop_size,
+ random_crop=random_crop,
+ additional_targets={"depth": "image"})
+ self.crop_size = crop_size
+ if self.crop_size is not None:
+ self.rescaler = albumentations.Compose(
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
+ additional_targets={"depth": "image"})
+ if root is not None:
+ self.DEFAULT_DEPTH_ROOT = root
+
+ def __len__(self):
+ return len(self.base_dset)
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
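+        # normalize depth to [0, 1], then map to [-1, 1]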
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = self.base_dset[i]
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
+        # upscale if the image is smaller than the crop size
+ h,w,c = e["image"].shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ out = self.rescaler(image=e["image"], depth=e["depth"])
+ e["image"] = out["image"]
+ e["depth"] = out["depth"]
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+class ImageNetTrainWithDepth(BaseWithDepth):
+ # default to random_crop=True
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetTrain()
+ else:
+ return ImageNetTrain({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
+ return fid
+
+
+class ImageNetValidationWithDepth(BaseWithDepth):
+ def __init__(self, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(**kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetValidation()
+ else:
+ return ImageNetValidation({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
+ return fid
+
+
+class RINTrainWithDepth(ImageNetTrainWithDepth):
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class RINValidationWithDepth(ImageNetValidationWithDepth):
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class DRINExamples(Dataset):
+ def __init__(self):
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
+ with open("data/drin_examples.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ self.image_paths = [os.path.join("data/drin_images",
+ relpath) for relpath in relpaths]
+ self.depth_paths = [os.path.join("data/drin_depth",
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = dict()
+ e["image"] = self.preprocess_image(self.image_paths[i])
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
+ if factor is None or factor==1:
+ return x
+
+ dtype = x.dtype
+ assert dtype in [np.float32, np.float64]
+ assert x.min() >= -1
+ assert x.max() <= 1
+
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
+ "bicubic": Image.BICUBIC}[keepmode]
+
+ lr = (x+1.0)*127.5
+ lr = lr.clip(0,255).astype(np.uint8)
+ lr = Image.fromarray(lr)
+
+ h, w, _ = x.shape
+ nh = h//factor
+ nw = w//factor
+ assert nh > 0 and nw > 0, (nh, nw)
+
+ lr = lr.resize((nw,nh), Image.BICUBIC)
+ if keepshapes:
+ lr = lr.resize((w,h), keepmode)
+ lr = np.array(lr)/127.5-1.0
+ lr = lr.astype(dtype)
+
+ return lr
+
+
+class ImageNetScale(Dataset):
+ def __init__(self, size=None, crop_size=None, random_crop=False,
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
+ self.base = self.get_base()
+
+ self.size = size
+ self.crop_size = crop_size if crop_size is not None else self.size
+ self.random_crop = random_crop
+ self.up_factor = up_factor
+ self.hr_factor = hr_factor
+ self.keep_mode = keep_mode
+
+ transforms = list()
+
+ if self.size is not None and self.size > 0:
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ self.rescaler = rescaler
+ transforms.append(rescaler)
+
+ if self.crop_size is not None and self.crop_size > 0:
+ if len(transforms) == 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
+
+ if not self.random_crop:
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
+ transforms.append(cropper)
+
+ if len(transforms) > 0:
+ if self.up_factor is not None:
+ additional_targets = {"lr": "image"}
+ else:
+ additional_targets = None
+ self.preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+        # optionally reduce the resolution of the base image by hr_factor
+ image = imscale(image, self.hr_factor, keepshapes=False)
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+ if self.up_factor is None:
+ image = self.preprocessor(image=image)["image"]
+ example["image"] = image
+ else:
+ lr = imscale(image, self.up_factor, keepshapes=True,
+ keepmode=self.keep_mode)
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+class ImageNetScaleTrain(ImageNetScale):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetScaleValidation(ImageNetScale):
+ def get_base(self):
+ return ImageNetValidation()
+
+
+from skimage.feature import canny
+from skimage.color import rgb2gray
+
+
+class ImageNetEdges(ImageNetScale):
+ def __init__(self, up_factor=1, **kwargs):
+ super().__init__(up_factor=1, **kwargs)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+
+ lr = canny(rgb2gray(image), sigma=2)
+ lr = lr.astype(np.float32)
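+        # replicate the single-channel edge map to three channels so it matches the image layout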
+ lr = lr[:,:,None][:,:,[0,0,0]]
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+
+class ImageNetEdgesTrain(ImageNetEdges):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetEdgesValidation(ImageNetEdges):
+ def get_base(self):
+ return ImageNetValidation()
diff --git a/repositories/taming-transformers/taming/data/open_images_helper.py b/repositories/taming-transformers/taming/data/open_images_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/open_images_helper.py
@@ -0,0 +1,379 @@
+open_images_unify_categories_for_coco = {
+ '/m/03bt1vf': '/m/01g317',
+ '/m/04yx4': '/m/01g317',
+ '/m/05r655': '/m/01g317',
+ '/m/01bl7v': '/m/01g317',
+ '/m/0cnyhnx': '/m/01xq0k1',
+ '/m/01226z': '/m/018xm',
+ '/m/05ctyq': '/m/018xm',
+ '/m/058qzx': '/m/04ctx',
+ '/m/06pcq': '/m/0l515',
+ '/m/03m3pdh': '/m/02crq1',
+ '/m/046dlr': '/m/01x3z',
+ '/m/0h8mzrc': '/m/01x3z',
+}
+
+
+top_300_classes_plus_coco_compatibility = [
+ ('Man', 1060962),
+ ('Clothing', 986610),
+ ('Tree', 748162),
+ ('Woman', 611896),
+ ('Person', 610294),
+ ('Human face', 442948),
+ ('Girl', 175399),
+ ('Building', 162147),
+ ('Car', 159135),
+ ('Plant', 155704),
+ ('Human body', 137073),
+ ('Flower', 133128),
+ ('Window', 127485),
+ ('Human arm', 118380),
+ ('House', 114365),
+ ('Wheel', 111684),
+ ('Suit', 99054),
+ ('Human hair', 98089),
+ ('Human head', 92763),
+ ('Chair', 88624),
+ ('Boy', 79849),
+ ('Table', 73699),
+ ('Jeans', 57200),
+ ('Tire', 55725),
+ ('Skyscraper', 53321),
+ ('Food', 52400),
+ ('Footwear', 50335),
+ ('Dress', 50236),
+ ('Human leg', 47124),
+ ('Toy', 46636),
+ ('Tower', 45605),
+ ('Boat', 43486),
+ ('Land vehicle', 40541),
+ ('Bicycle wheel', 34646),
+ ('Palm tree', 33729),
+ ('Fashion accessory', 32914),
+ ('Glasses', 31940),
+ ('Bicycle', 31409),
+ ('Furniture', 30656),
+ ('Sculpture', 29643),
+ ('Bottle', 27558),
+ ('Dog', 26980),
+ ('Snack', 26796),
+ ('Human hand', 26664),
+ ('Bird', 25791),
+ ('Book', 25415),
+ ('Guitar', 24386),
+ ('Jacket', 23998),
+ ('Poster', 22192),
+ ('Dessert', 21284),
+ ('Baked goods', 20657),
+ ('Drink', 19754),
+ ('Flag', 18588),
+ ('Houseplant', 18205),
+ ('Tableware', 17613),
+ ('Airplane', 17218),
+ ('Door', 17195),
+ ('Sports uniform', 17068),
+ ('Shelf', 16865),
+ ('Drum', 16612),
+ ('Vehicle', 16542),
+ ('Microphone', 15269),
+ ('Street light', 14957),
+ ('Cat', 14879),
+ ('Fruit', 13684),
+ ('Fast food', 13536),
+ ('Animal', 12932),
+ ('Vegetable', 12534),
+ ('Train', 12358),
+ ('Horse', 11948),
+ ('Flowerpot', 11728),
+ ('Motorcycle', 11621),
+ ('Fish', 11517),
+ ('Desk', 11405),
+ ('Helmet', 10996),
+ ('Truck', 10915),
+ ('Bus', 10695),
+ ('Hat', 10532),
+ ('Auto part', 10488),
+ ('Musical instrument', 10303),
+ ('Sunglasses', 10207),
+ ('Picture frame', 10096),
+ ('Sports equipment', 10015),
+ ('Shorts', 9999),
+ ('Wine glass', 9632),
+ ('Duck', 9242),
+ ('Wine', 9032),
+ ('Rose', 8781),
+ ('Tie', 8693),
+ ('Butterfly', 8436),
+ ('Beer', 7978),
+ ('Cabinetry', 7956),
+ ('Laptop', 7907),
+ ('Insect', 7497),
+ ('Goggles', 7363),
+ ('Shirt', 7098),
+ ('Dairy Product', 7021),
+ ('Marine invertebrates', 7014),
+ ('Cattle', 7006),
+ ('Trousers', 6903),
+ ('Van', 6843),
+ ('Billboard', 6777),
+ ('Balloon', 6367),
+ ('Human nose', 6103),
+ ('Tent', 6073),
+ ('Camera', 6014),
+ ('Doll', 6002),
+ ('Coat', 5951),
+ ('Mobile phone', 5758),
+ ('Swimwear', 5729),
+ ('Strawberry', 5691),
+ ('Stairs', 5643),
+ ('Goose', 5599),
+ ('Umbrella', 5536),
+ ('Cake', 5508),
+ ('Sun hat', 5475),
+ ('Bench', 5310),
+ ('Bookcase', 5163),
+ ('Bee', 5140),
+ ('Computer monitor', 5078),
+ ('Hiking equipment', 4983),
+ ('Office building', 4981),
+ ('Coffee cup', 4748),
+ ('Curtain', 4685),
+ ('Plate', 4651),
+ ('Box', 4621),
+ ('Tomato', 4595),
+ ('Coffee table', 4529),
+ ('Office supplies', 4473),
+ ('Maple', 4416),
+ ('Muffin', 4365),
+ ('Cocktail', 4234),
+ ('Castle', 4197),
+ ('Couch', 4134),
+ ('Pumpkin', 3983),
+ ('Computer keyboard', 3960),
+ ('Human mouth', 3926),
+ ('Christmas tree', 3893),
+ ('Mushroom', 3883),
+ ('Swimming pool', 3809),
+ ('Pastry', 3799),
+ ('Lavender (Plant)', 3769),
+ ('Football helmet', 3732),
+ ('Bread', 3648),
+ ('Traffic sign', 3628),
+ ('Common sunflower', 3597),
+ ('Television', 3550),
+ ('Bed', 3525),
+ ('Cookie', 3485),
+ ('Fountain', 3484),
+ ('Paddle', 3447),
+ ('Bicycle helmet', 3429),
+ ('Porch', 3420),
+ ('Deer', 3387),
+ ('Fedora', 3339),
+ ('Canoe', 3338),
+ ('Carnivore', 3266),
+ ('Bowl', 3202),
+ ('Human eye', 3166),
+ ('Ball', 3118),
+ ('Pillow', 3077),
+ ('Salad', 3061),
+ ('Beetle', 3060),
+ ('Orange', 3050),
+ ('Drawer', 2958),
+ ('Platter', 2937),
+ ('Elephant', 2921),
+ ('Seafood', 2921),
+ ('Monkey', 2915),
+ ('Countertop', 2879),
+ ('Watercraft', 2831),
+ ('Helicopter', 2805),
+ ('Kitchen appliance', 2797),
+ ('Personal flotation device', 2781),
+ ('Swan', 2739),
+ ('Lamp', 2711),
+ ('Boot', 2695),
+ ('Bronze sculpture', 2693),
+ ('Chicken', 2677),
+ ('Taxi', 2643),
+ ('Juice', 2615),
+ ('Cowboy hat', 2604),
+ ('Apple', 2600),
+ ('Tin can', 2590),
+ ('Necklace', 2564),
+ ('Ice cream', 2560),
+ ('Human beard', 2539),
+ ('Coin', 2536),
+ ('Candle', 2515),
+ ('Cart', 2512),
+ ('High heels', 2441),
+ ('Weapon', 2433),
+ ('Handbag', 2406),
+ ('Penguin', 2396),
+ ('Rifle', 2352),
+ ('Violin', 2336),
+ ('Skull', 2304),
+ ('Lantern', 2285),
+ ('Scarf', 2269),
+ ('Saucer', 2225),
+ ('Sheep', 2215),
+ ('Vase', 2189),
+ ('Lily', 2180),
+ ('Mug', 2154),
+ ('Parrot', 2140),
+ ('Human ear', 2137),
+ ('Sandal', 2115),
+ ('Lizard', 2100),
+ ('Kitchen & dining room table', 2063),
+ ('Spider', 1977),
+ ('Coffee', 1974),
+ ('Goat', 1926),
+ ('Squirrel', 1922),
+ ('Cello', 1913),
+ ('Sushi', 1881),
+ ('Tortoise', 1876),
+ ('Pizza', 1870),
+ ('Studio couch', 1864),
+ ('Barrel', 1862),
+ ('Cosmetics', 1841),
+ ('Moths and butterflies', 1841),
+ ('Convenience store', 1817),
+ ('Watch', 1792),
+ ('Home appliance', 1786),
+ ('Harbor seal', 1780),
+ ('Luggage and bags', 1756),
+ ('Vehicle registration plate', 1754),
+ ('Shrimp', 1751),
+ ('Jellyfish', 1730),
+ ('French fries', 1723),
+ ('Egg (Food)', 1698),
+ ('Football', 1697),
+ ('Musical keyboard', 1683),
+ ('Falcon', 1674),
+ ('Candy', 1660),
+ ('Medical equipment', 1654),
+ ('Eagle', 1651),
+ ('Dinosaur', 1634),
+ ('Surfboard', 1630),
+ ('Tank', 1628),
+ ('Grape', 1624),
+ ('Lion', 1624),
+ ('Owl', 1622),
+ ('Ski', 1613),
+ ('Waste container', 1606),
+ ('Frog', 1591),
+ ('Sparrow', 1585),
+ ('Rabbit', 1581),
+ ('Pen', 1546),
+ ('Sea lion', 1537),
+ ('Spoon', 1521),
+ ('Sink', 1512),
+ ('Teddy bear', 1507),
+ ('Bull', 1495),
+ ('Sofa bed', 1490),
+ ('Dragonfly', 1479),
+ ('Brassiere', 1478),
+ ('Chest of drawers', 1472),
+ ('Aircraft', 1466),
+ ('Human foot', 1463),
+ ('Pig', 1455),
+ ('Fork', 1454),
+ ('Antelope', 1438),
+ ('Tripod', 1427),
+ ('Tool', 1424),
+ ('Cheese', 1422),
+ ('Lemon', 1397),
+ ('Hamburger', 1393),
+ ('Dolphin', 1390),
+ ('Mirror', 1390),
+ ('Marine mammal', 1387),
+ ('Giraffe', 1385),
+ ('Snake', 1368),
+ ('Gondola', 1364),
+ ('Wheelchair', 1360),
+ ('Piano', 1358),
+ ('Cupboard', 1348),
+ ('Banana', 1345),
+ ('Trumpet', 1335),
+ ('Lighthouse', 1333),
+ ('Invertebrate', 1317),
+ ('Carrot', 1268),
+ ('Sock', 1260),
+ ('Tiger', 1241),
+ ('Camel', 1224),
+ ('Parachute', 1224),
+ ('Bathroom accessory', 1223),
+ ('Earrings', 1221),
+ ('Headphones', 1218),
+ ('Skirt', 1198),
+ ('Skateboard', 1190),
+ ('Sandwich', 1148),
+ ('Saxophone', 1141),
+ ('Goldfish', 1136),
+ ('Stool', 1104),
+ ('Traffic light', 1097),
+ ('Shellfish', 1081),
+ ('Backpack', 1079),
+ ('Sea turtle', 1078),
+ ('Cucumber', 1075),
+ ('Tea', 1051),
+ ('Toilet', 1047),
+ ('Roller skates', 1040),
+ ('Mule', 1039),
+ ('Bust', 1031),
+ ('Broccoli', 1030),
+ ('Crab', 1020),
+ ('Oyster', 1019),
+ ('Cannon', 1012),
+ ('Zebra', 1012),
+ ('French horn', 1008),
+ ('Grapefruit', 998),
+ ('Whiteboard', 997),
+ ('Zucchini', 997),
+ ('Crocodile', 992),
+
+ ('Clock', 960),
+ ('Wall clock', 958),
+
+ ('Doughnut', 869),
+ ('Snail', 868),
+
+ ('Baseball glove', 859),
+
+ ('Panda', 830),
+ ('Tennis racket', 830),
+
+ ('Pear', 652),
+
+ ('Bagel', 617),
+ ('Oven', 616),
+ ('Ladybug', 615),
+ ('Shark', 615),
+ ('Polar bear', 614),
+ ('Ostrich', 609),
+
+ ('Hot dog', 473),
+ ('Microwave oven', 467),
+ ('Fire hydrant', 20),
+ ('Stop sign', 20),
+ ('Parking meter', 20),
+ ('Bear', 20),
+ ('Flying disc', 20),
+ ('Snowboard', 20),
+ ('Tennis ball', 20),
+ ('Kite', 20),
+ ('Baseball bat', 20),
+ ('Kitchen knife', 20),
+ ('Knife', 20),
+ ('Submarine sandwich', 20),
+ ('Computer mouse', 20),
+ ('Remote control', 20),
+ ('Toaster', 20),
+ ('Sink', 20),
+ ('Refrigerator', 20),
+ ('Alarm clock', 20),
+ ('Wall clock', 20),
+ ('Scissors', 20),
+ ('Hair dryer', 20),
+ ('Toothbrush', 20),
+ ('Suitcase', 20)
+]
diff --git a/repositories/taming-transformers/taming/data/sflckr.py b/repositories/taming-transformers/taming/data/sflckr.py
new file mode 100644
index 0000000000000000000000000000000000000000..91101be5953b113f1e58376af637e43f366b3dee
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/sflckr.py
@@ -0,0 +1,91 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class SegmentationBase(Dataset):
+ def __init__(self,
+ data_csv, data_root, segmentation_root,
+ size=None, random_crop=False, interpolation="bicubic",
+ n_labels=182, shift_segmentation=False,
+ ):
+ self.n_labels = n_labels
+ self.shift_segmentation = shift_segmentation
+ self.data_csv = data_csv
+ self.data_root = data_root
+ self.segmentation_root = segmentation_root
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
+ for l in self.image_paths]
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ assert segmentation.mode == "L", segmentation.mode
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.shift_segmentation:
+ # used to support segmentations containing unlabeled==255 label
+ segmentation = segmentation+1
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image,
+ mask=segmentation
+ )
+ else:
+ processed = {"image": image,
+ "mask": segmentation
+ }
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
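+        # one-hot encode the (H, W) integer label map into shape (H, W, n_labels)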
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/sflckr_examples.txt",
+ data_root="data/sflckr_images",
+ segmentation_root="data/sflckr_segmentations",
+ size=size, random_crop=random_crop, interpolation=interpolation)
diff --git a/repositories/taming-transformers/taming/data/utils.py b/repositories/taming-transformers/taming/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b3c3d53cd2b6c72b481b59834cf809d3735b394
--- /dev/null
+++ b/repositories/taming-transformers/taming/data/utils.py
@@ -0,0 +1,169 @@
+import collections
+import os
+import tarfile
+import urllib
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from taming.data.helper_types import Annotation
+from torch._six import string_classes
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+from tqdm import tqdm
+
+
+def unpack(path):
+ if path.endswith("tar.gz"):
+ with tarfile.open(path, "r:gz") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("tar"):
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("zip"):
+ with zipfile.ZipFile(path, "r") as f:
+ f.extractall(path=os.path.split(path)[0])
+ else:
+ raise NotImplementedError(
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
+ )
+
+
+def reporthook(bar):
+ """tqdm progress bar for downloads."""
+
+ def hook(b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ bar.total = tsize
+ bar.update(b * bsize - bar.n)
+
+ return hook
+
+
+def get_root(name):
+ base = "data/"
+ root = os.path.join(base, name)
+ os.makedirs(root, exist_ok=True)
+ return root
+
+
+def is_prepared(root):
+ return Path(root).joinpath(".ready").exists()
+
+
+def mark_prepared(root):
+ Path(root).joinpath(".ready").touch()
+
+
+def prompt_download(file_, source, target_dir, content_dir=None):
+ targetpath = os.path.join(target_dir, file_)
+ while not os.path.exists(targetpath):
+ if content_dir is not None and os.path.exists(
+ os.path.join(target_dir, content_dir)
+ ):
+ break
+ print(
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+ )
+ if content_dir is not None:
+ print(
+ "Or place its content into '{}'.".format(
+ os.path.join(target_dir, content_dir)
+ )
+ )
+ input("Press Enter when done...")
+ return targetpath
+
+
+def download_url(file_, url, target_dir):
+ targetpath = os.path.join(target_dir, file_)
+ os.makedirs(target_dir, exist_ok=True)
+ with tqdm(
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+ ) as bar:
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+ return targetpath
+
+
+def download_urls(urls, target_dir):
+ paths = dict()
+ for fname, url in urls.items():
+ outpath = download_url(fname, url, target_dir)
+ paths[fname] = outpath
+ return paths
+
+
+def quadratic_crop(x, bbox, alpha=1.0):
+ """bbox is xmin, ymin, xmax, ymax"""
+ im_h, im_w = x.shape[:2]
+ bbox = np.array(bbox, dtype=np.float32)
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ l = int(alpha * max(w, h))
+ l = max(l, 2)
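+    # reflect-pad the image if the square crop around the bbox center would extend past its borders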
+
+ required_padding = -1 * min(
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+ )
+ required_padding = int(np.ceil(required_padding))
+ if required_padding > 0:
+ padding = [
+ [required_padding, required_padding],
+ [required_padding, required_padding],
+ ]
+ padding += [[0, 0]] * (len(x.shape) - 2)
+ x = np.pad(x, padding, "reflect")
+ center = center[0] + required_padding, center[1] + required_padding
+ xmin = int(center[0] - l / 2)
+ ymin = int(center[1] - l / 2)
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+ r"""source: pytorch 1.9.0, only one modification to original code """
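+    # behaves like torch.utils.data.default_collate, except that sequences of Annotation objects
+    # are returned as-is instead of being collated (see the lines marked "added" below)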
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return custom_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
+ return batch # added
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError('each element in list of batch should be of equal size')
+ transposed = zip(*batch)
+ return [custom_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
diff --git a/repositories/taming-transformers/taming/lr_scheduler.py b/repositories/taming-transformers/taming/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e598ed120159c53da6820a55ad86b89f5c70c82d
--- /dev/null
+++ b/repositories/taming-transformers/taming/lr_scheduler.py
@@ -0,0 +1,34 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
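+        # linear warm-up from lr_start to lr_max over warm_up_steps, then cosine decay towards lr_min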
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n):
+ return self.schedule(n)
+
diff --git a/repositories/taming-transformers/taming/models/cond_transformer.py b/repositories/taming-transformers/taming/models/cond_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4c63730fa86ac1b92b37af14c14fb696595b1ab
--- /dev/null
+++ b/repositories/taming-transformers/taming/models/cond_transformer.py
@@ -0,0 +1,352 @@
+import os, math
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+from taming.modules.util import SOSProvider
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class Net2NetTransformer(pl.LightningModule):
+ def __init__(self,
+ transformer_config,
+ first_stage_config,
+ cond_stage_config,
+ permuter_config=None,
+ ckpt_path=None,
+ ignore_keys=[],
+ first_stage_key="image",
+ cond_stage_key="depth",
+ downsample_cond_size=-1,
+ pkeep=1.0,
+ sos_token=0,
+ unconditional=False,
+ ):
+ super().__init__()
+ self.be_unconditional = unconditional
+ self.sos_token = sos_token
+ self.first_stage_key = first_stage_key
+ self.cond_stage_key = cond_stage_key
+ self.init_first_stage_from_ckpt(first_stage_config)
+ self.init_cond_stage_from_ckpt(cond_stage_config)
+ if permuter_config is None:
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
+ self.permuter = instantiate_from_config(config=permuter_config)
+ self.transformer = instantiate_from_config(config=transformer_config)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.downsample_cond_size = downsample_cond_size
+ self.pkeep = pkeep
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ for k in sd.keys():
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ self.print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def init_first_stage_from_ckpt(self, config):
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.first_stage_model = model
+
+ def init_cond_stage_from_ckpt(self, config):
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__" or self.be_unconditional:
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
+ f"Prepending {self.sos_token} as a sos token.")
+ self.be_unconditional = True
+ self.cond_stage_key = self.first_stage_key
+ self.cond_stage_model = SOSProvider(self.sos_token)
+ else:
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.cond_stage_model = model
+
+ def forward(self, x, c):
+ # one step to produce the logits
+ _, z_indices = self.encode_to_z(x)
+ _, c_indices = self.encode_to_c(c)
+
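+        # keep each ground-truth z token with probability pkeep; replace the rest with random codebook indices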
+ if self.training and self.pkeep < 1.0:
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
+ device=z_indices.device))
+ mask = mask.round().to(dtype=torch.int64)
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+ a_indices = mask*z_indices+(1-mask)*r_indices
+ else:
+ a_indices = z_indices
+
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
+
+ # target includes all sequence elements (no need to handle first one
+ # differently because we are conditioning)
+ target = z_indices
+ # make the prediction
+ logits, _ = self.transformer(cz_indices[:, :-1])
+        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+        logits = logits[:, c_indices.shape[1]-1:]
+
+        return logits, target
+
+    def top_k_logits(self, logits, k):
+        v, ix = torch.topk(logits, k)
+        out = logits.clone()
+        out[out < v[..., [-1]]] = -float('Inf')
+        return out
+
+    @torch.no_grad()
+    def encode_to_z(self, x):
+        quant_z, _, info = self.first_stage_model.encode(x)
+        indices = info[2].view(quant_z.shape[0], -1)
+        indices = self.permuter(indices)
+        return quant_z, indices
+
+    @torch.no_grad()
+    def encode_to_c(self, c):
+        if self.downsample_cond_size > -1:
+            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
+ if len(indices.shape) > 2:
+ indices = indices.view(c.shape[0], -1)
+ return quant_c, indices
+
+ @torch.no_grad()
+ def decode_to_img(self, index, zshape):
+ index = self.permuter(index, reverse=True)
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
+ index.reshape(-1), shape=bhwc)
+ x = self.first_stage_model.decode(quant_z)
+ return x
+
+ @torch.no_grad()
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+ log = dict()
+
+ N = 4
+ if lr_interface:
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
+ else:
+ x, c = self.get_xc(batch, N)
+ x = x.to(device=self.device)
+ c = c.to(device=self.device)
+
+ quant_z, z_indices = self.encode_to_z(x)
+ quant_c, c_indices = self.encode_to_c(c)
+
+        # create a "half" sample: complete the remaining z tokens given the first half
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
+
+ # sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
+
+ # det sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ sample=False,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
+
+ # reconstruction
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
+
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
+ label_for_category_no = dataset.get_textual_label_for_category_no
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
+ for i in range(quant_c.shape[0]):
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
+ log["conditioning_rec"] = log["conditioning"]
+ elif self.cond_stage_key != "image":
+ cond_rec = self.cond_stage_model.decode(quant_c)
+ if self.cond_stage_key == "segmentation":
+ # get image from segmentation mask
+ num_classes = cond_rec.shape[1]
+
+ c = torch.argmax(c, dim=1, keepdim=True)
+ c = F.one_hot(c, num_classes=num_classes)
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ c = self.cond_stage_model.to_rgb(c)
+
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+ log["conditioning_rec"] = cond_rec
+ log["conditioning"] = c
+
+ log["samples_half"] = x_sample
+ log["samples_nopix"] = x_sample_nopix
+ log["samples_det"] = x_sample_det
+ return log
+
+ def get_input(self, key, batch):
+ x = batch[key]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ if len(x.shape) == 4:
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ if x.dtype == torch.double:
+ x = x.float()
+ return x
+
+ def get_xc(self, batch, N=None):
+ x = self.get_input(self.first_stage_key, batch)
+ c = self.get_input(self.cond_stage_key, batch)
+ if N is not None:
+ x = x[:N]
+ c = c[:N]
+ return x, c
+
+ def shared_step(self, batch, batch_idx):
+ x, c = self.get_xc(batch)
+ logits, target = self(x, c)
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def configure_optimizers(self):
+ """
+ Following minGPT:
+ This long function is unfortunately doing something very simple and is being very defensive:
+ We are separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ We are then returning the PyTorch optimizer object.
+ """
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear, )
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.transformer.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add('pos_emb')
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params), )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+ return optimizer
diff --git a/repositories/taming-transformers/taming/models/dummy_cond_stage.py b/repositories/taming-transformers/taming/models/dummy_cond_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e19938078752e09b926a3e749907ee99a258ca0
--- /dev/null
+++ b/repositories/taming-transformers/taming/models/dummy_cond_stage.py
@@ -0,0 +1,22 @@
+from torch import Tensor
+
+
+class DummyCondStage:
+ def __init__(self, conditional_key):
+ self.conditional_key = conditional_key
+ self.train = None
+
+ def eval(self):
+ return self
+
+ @staticmethod
+ def encode(c: Tensor):
+ return c, None, (None, None, c)
+
+ @staticmethod
+ def decode(c: Tensor):
+ return c
+
+ @staticmethod
+ def to_rgb(c: Tensor):
+ return c
diff --git a/repositories/taming-transformers/taming/models/vqgan.py b/repositories/taming-transformers/taming/models/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6950baa5f739111cd64c17235dca8be3a5f8037
--- /dev/null
+++ b/repositories/taming-transformers/taming/models/vqgan.py
@@ -0,0 +1,404 @@
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+
+from taming.modules.diffusionmodules.model import Encoder, Decoder
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from taming.modules.vqvae.quantize import GumbelQuantize
+from taming.modules.vqvae.quantize import EMAVectorQuantizer
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap, sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.image_key = image_key
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ return x.float()
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
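+        # optimizer_idx selects between the two optimizers from configure_optimizers:
+        # 0 updates the autoencoder (encoder/decoder/codebook), 1 updates the discriminator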
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQSegmentationModel(VQModel):
+ def __init__(self, n_labels, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ return opt_ae
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ total_loss = log_dict_ae["val/total_loss"]
+ self.log("val/total_loss", total_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ return aeloss
+
+ @torch.no_grad()
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+
+class VQNoDiscModel(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None
+ ):
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
+ colorize_nlabels=colorize_nlabels)
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
+ output = pl.TrainResult(minimize=aeloss)
+ output.log("train/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return output
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ output = pl.EvalResult(checkpoint_on=rec_loss)
+ output.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log_dict(log_dict_ae)
+
+ return output
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=self.learning_rate, betas=(0.5, 0.9))
+ return optimizer
+
+
+class GumbelVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ temperature_scheduler_config,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ kl_weight=1e-8,
+ remap=None,
+ ):
+
+ z_channels = ddconfig["z_channels"]
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+
+ self.loss.n_classes = n_embed
+ self.vocab_size = n_embed
+
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
+ n_embed=n_embed,
+ kl_weight=kl_weight, temp_init=1.0,
+ remap=remap)
+
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def temperature_scheduling(self):
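+        # anneal the Gumbel-softmax temperature with the configured scheduler as training progresses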
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode_code(self, code_b):
+ raise NotImplementedError
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ self.temperature_scheduling()
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ # encode
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, _, _ = self.quantize(h)
+ # decode
+ x_rec = self.decode(quant)
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+ return log
+
+
+class EMAVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
+ embedding_dim=embed_dim,
+ beta=0.25,
+ remap=remap)
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        # Remove self.quantize from the parameter list since it is updated via EMA
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
\ No newline at end of file
diff --git a/repositories/taming-transformers/taming/modules/diffusionmodules/model.py b/repositories/taming-transformers/taming/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/diffusionmodules/model.py
@@ -0,0 +1,776 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+    Build sinusoidal timestep embeddings.
+    This matches the implementation in Denoising Diffusion Probabilistic Models
+    (and, before that, Fairseq / tensor2tensor), but differs slightly from the
+    description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
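+
+# A minimal usage sketch (illustrative helper, not part of the original module):
+# it only demonstrates the expected shapes, assuming a 1-D tensor of integer
+# timesteps and relying on the zero-padding above for odd embedding_dim.
+def _timestep_embedding_demo():
+    t = torch.arange(4)                                 # shape (4,)
+    emb = get_timestep_embedding(t, embedding_dim=128)  # [sin | cos] halves
+    assert emb.shape == (4, 128)
+    return emb
+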
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
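+
+# A minimal shape sketch (illustrative helper, not part of the original module):
+# AttnBlock is single-head self-attention over the h*w spatial positions with a
+# residual connection, so the output has exactly the input's shape.
+def _attnblock_demo():
+    blk = AttnBlock(in_channels=64)  # channels must be divisible by the 32 GroupNorm groups
+    x = torch.randn(2, 64, 16, 16)
+    assert blk(x).shape == x.shape
+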
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x):
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
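+
+# A minimal round-trip sketch (illustrative helper, not part of the original
+# module): Encoder shrinks the spatial size by 2**(len(ch_mult)-1) and maps to
+# z_channels (or 2*z_channels when double_z=True); Decoder maps z back up to
+# the original `resolution` with `out_ch` channels.
+def _encoder_decoder_demo():
+    kw = dict(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
+              attn_resolutions=[], in_channels=3, resolution=32, z_channels=4)
+    enc = Encoder(double_z=False, **kw)
+    dec = Decoder(**kw)
+    z = enc(torch.randn(1, 3, 32, 32))
+    assert z.shape == (1, 4, 16, 16)
+    assert dec(z).shape == (1, 3, 32, 32)
+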
+
+class VUNet(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ in_channels, c_channels,
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(c_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ self.z_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+    def forward(self, x, z, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ z = self.z_in(z)
+ h = torch.cat((h,z),dim=1)
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
diff --git a/repositories/taming-transformers/taming/modules/discriminator/model.py b/repositories/taming-transformers/taming/modules/discriminator/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aaa3110d0a7bcd05de7eca1e45101589ca5af05
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/discriminator/model.py
@@ -0,0 +1,67 @@
+import functools
+import torch.nn as nn
+
+
+from taming.modules.util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
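+
+
+# A minimal shape sketch (illustrative helper, not part of the original module):
+# with the default n_layers=3 the discriminator maps a 256x256 image to a 30x30
+# grid of patch logits, i.e. one real/fake score per overlapping image patch.
+def _patchgan_demo():
+    import torch  # the module above only imports torch.nn
+    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
+    logits = disc(torch.randn(2, 3, 256, 256))
+    assert logits.shape == (2, 1, 30, 30)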
diff --git a/repositories/taming-transformers/taming/modules/losses/__init__.py b/repositories/taming-transformers/taming/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d09caf9eb805f849a517f1b23503e1a4d6ea1ec5
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/losses/__init__.py
@@ -0,0 +1,2 @@
+from taming.modules.losses.vqperceptual import DummyLoss
+
diff --git a/repositories/taming-transformers/taming/modules/losses/lpips.py b/repositories/taming-transformers/taming/modules/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7280447694ffc302a7636e7e4d6183408e0aa95
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/losses/lpips.py
@@ -0,0 +1,123 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from collections import namedtuple
+
+from taming.util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer which does a 1x1 conv """
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = [nn.Dropout(), ] if (use_dropout) else []
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x,eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
+ return x/(norm_factor+eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2,3],keepdim=keepdim)
+
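+
+# A minimal sketch of the per-layer LPIPS building block (illustrative helper,
+# not part of the original module): features are unit-normalized along the
+# channel axis and squared differences are averaged over space; the full metric
+# additionally applies the learned 1x1 NetLinLayer weights before averaging.
+def _lpips_building_block_demo():
+    f0 = normalize_tensor(torch.randn(1, 8, 4, 4))
+    f1 = normalize_tensor(torch.randn(1, 8, 4, 4))
+    d = spatial_average((f0 - f1) ** 2)  # shape (1, 8, 1, 1)
+    assert d.shape == (1, 8, 1, 1)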
diff --git a/repositories/taming-transformers/taming/modules/losses/segmentation.py b/repositories/taming-transformers/taming/modules/losses/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/losses/segmentation.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BCELoss(nn.Module):
+ def forward(self, prediction, target):
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
+ return loss, {}
+
+
+class BCELossWithQuant(nn.Module):
+ def __init__(self, codebook_weight=1.):
+ super().__init__()
+ self.codebook_weight = codebook_weight
+
+ def forward(self, qloss, target, prediction, split):
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
+ loss = bce_loss + self.codebook_weight*qloss
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
+ "{}/quant_loss".format(split): qloss.detach().mean()
+ }
diff --git a/repositories/taming-transformers/taming/modules/losses/vqperceptual.py b/repositories/taming-transformers/taming/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2febd445728479d4cd9aacdb2572cb1f1af04db
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/losses/vqperceptual.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+
+
+class DummyLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
+ return d_loss
+
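+
+# A minimal numeric sketch (illustrative helper, not part of the original
+# module): with confidently separated logits the hinge loss saturates at zero,
+# while the vanilla (softplus) loss still assigns a small positive penalty.
+def _d_loss_demo():
+    real = torch.tensor([2.0, 3.0])    # discriminator output on real images
+    fake = torch.tensor([-2.0, -3.0])  # discriminator output on reconstructions
+    assert hinge_d_loss(real, fake).item() == 0.0
+    assert vanilla_d_loss(real, fake).item() > 0.0
+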
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train"):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
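+
+
+# A standalone sketch of calculate_adaptive_weight above (illustrative helper,
+# not part of the original module): the GAN loss is rescaled so that its
+# gradient norm at a shared last layer matches the reconstruction-loss gradient
+# norm, d_weight = ||grad(nll)|| / (||grad(g)|| + 1e-4), clamped to [0, 1e4].
+def _adaptive_weight_demo():
+    last_layer = nn.Parameter(torch.randn(4, 4))
+    nll_loss = (last_layer ** 2).sum()
+    g_loss = last_layer.sum()
+    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+    return torch.clamp(d_weight, 0.0, 1e4).detach()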
diff --git a/repositories/taming-transformers/taming/modules/misc/coord.py b/repositories/taming-transformers/taming/modules/misc/coord.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/misc/coord.py
@@ -0,0 +1,31 @@
+import torch
+
+class CoordStage(object):
+ def __init__(self, n_embed, down_factor):
+ self.n_embed = n_embed
+ self.down_factor = down_factor
+
+ def eval(self):
+ return self
+
+ def encode(self, c):
+ """fake vqmodel interface"""
+ assert 0.0 <= c.min() and c.max() <= 1.0
+ b,ch,h,w = c.shape
+ assert ch == 1
+
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
+ mode="area")
+ c = c.clamp(0.0, 1.0)
+ c = self.n_embed*c
+ c_quant = c.round()
+ c_ind = c_quant.to(dtype=torch.long)
+
+ info = None, None, c_ind
+ return c_quant, None, info
+
+ def decode(self, c):
+ c = c/self.n_embed
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
+ mode="nearest")
+ return c
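+
+
+# A minimal sketch (illustrative helper, not part of the original module):
+# CoordStage mimics the VQ-model interface by binning a [0, 1] coordinate map
+# into n_embed integer codes at a down_factor-times lower resolution.
+def _coord_stage_demo():
+    stage = CoordStage(n_embed=256, down_factor=4)
+    coords = torch.rand(1, 1, 32, 32)
+    c_quant, _, (_, _, c_ind) = stage.encode(coords)
+    assert c_quant.shape == (1, 1, 8, 8)
+    assert c_ind.dtype == torch.long
+    assert stage.decode(c_quant).shape == (1, 1, 32, 32)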
diff --git a/repositories/taming-transformers/taming/modules/transformer/mingpt.py b/repositories/taming-transformers/taming/modules/transformer/mingpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d14b7b68117f4b9f297b2929397cd4f55089334c
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/transformer/mingpt.py
@@ -0,0 +1,415 @@
+"""
+taken from: https://github.com/karpathy/minGPT/
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+ - all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from transformers import top_k_top_p_filtering
+
+logger = logging.getLogger(__name__)
+
+
+class GPTConfig:
+ """ base GPT config, params common to all GPT versions """
+ embd_pdrop = 0.1
+ resid_pdrop = 0.1
+ attn_pdrop = 0.1
+
+ def __init__(self, vocab_size, block_size, **kwargs):
+ self.vocab_size = vocab_size
+ self.block_size = block_size
+ for k,v in kwargs.items():
+ setattr(self, k, v)
+
+
+class GPT1Config(GPTConfig):
+ """ GPT-1 like network roughly 125M params """
+ n_layer = 12
+ n_head = 12
+ n_embd = 768
+
+
+class CausalSelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+    It is possible to use torch.nn.MultiheadAttention here, but I am including
+    an explicit implementation to show that there is nothing too scary going on.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ assert config.n_embd % config.n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ mask = torch.tril(torch.ones(config.block_size,
+ config.block_size))
+ if hasattr(config, "n_unmasked"):
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
+ self.n_head = config.n_head
+
+ def forward(self, x, layer_past=None):
+ B, T, C = x.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+ present = torch.stack((k, v))
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ if layer_past is None:
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+
+ att = F.softmax(att, dim=-1)
+ att = self.attn_drop(att)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+ return y, present # TODO: check that this does not break anything
+
+
+class Block(nn.Module):
+ """ an unassuming Transformer block """
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+ self.mlp = nn.Sequential(
+ nn.Linear(config.n_embd, 4 * config.n_embd),
+ nn.GELU(), # nice
+ nn.Linear(4 * config.n_embd, config.n_embd),
+ nn.Dropout(config.resid_pdrop),
+ )
+
+ def forward(self, x, layer_past=None, return_present=False):
+ # TODO: check that training still works
+ if return_present: assert not self.training
+ # layer past: tuple of length two with B, nh, T, hs
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
+
+ x = x + attn
+ x = x + self.mlp(self.ln2(x))
+ if layer_past is not None or return_present:
+ return x, present
+ return x
+
+
+class GPT(nn.Module):
+ """ the full GPT language model, with a context size of block_size """
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
+ # inference only
+ assert not self.training
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ if past is not None:
+ assert past_length is not None
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
+ past_shape = list(past.shape)
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
+ else:
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
+
+ x = self.drop(token_embeddings + position_embeddings)
+ presents = [] # accumulate over layers
+ for i, block in enumerate(self.blocks):
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
+ presents.append(present)
+
+ x = self.ln_f(x)
+ logits = self.head(x)
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
+
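+
+# A minimal usage sketch (illustrative helper, not part of the original module):
+# a tiny GPT over a 16-token vocabulary; forward() returns per-position logits
+# and, when targets are given, the cross-entropy loss.
+def _gpt_demo():
+    gpt = GPT(vocab_size=16, block_size=8, n_layer=1, n_head=2, n_embd=32)
+    idx = torch.randint(0, 16, (2, 8))
+    logits, loss = gpt(idx, targets=idx)
+    assert logits.shape == (2, 8, 16)
+    assert loss.requires_grad
+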
+
+class DummyGPT(nn.Module):
+ # for debugging
+ def __init__(self, add_value=1):
+ super().__init__()
+ self.add_value = add_value
+
+ def forward(self, idx):
+ return idx + self.add_value, None
+
+
+class CodeGPT(nn.Module):
+ """Takes in semi-embeddings"""
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+        x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+
+
+#### sampling utils
+
+def top_k_logits(logits, k):
+ v, ix = torch.topk(logits, k)
+ out = logits.clone()
+ out[out < v[:, [-1]]] = -float('Inf')
+ return out
+
+@torch.no_grad()
+def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
+ """
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+ of block_size, unlike an RNN that has an infinite context window.
+ """
+ block_size = model.get_block_size()
+ model.eval()
+ for k in range(steps):
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
+ logits, _ = model(x_cond)
+ # pluck the logits at the final step and scale by temperature
+ logits = logits[:, -1, :] / temperature
+ # optionally crop probabilities to only the top k options
+ if top_k is not None:
+ logits = top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = F.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ # append to the sequence and continue
+ x = torch.cat((x, ix), dim=1)
+
+ return x
+
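+
+# A minimal sketch of greedy decoding with sample() above (illustrative helper,
+# not part of the original module): starting from a length-2 conditioning
+# prefix, four tokens are generated one at a time.
+def _sample_demo():
+    gpt = GPT(vocab_size=16, block_size=8, n_layer=1, n_head=2, n_embd=32)
+    cond = torch.randint(0, 16, (1, 2))
+    out = sample(gpt, cond, steps=4, sample=False, top_k=None)
+    assert out.shape == (1, 6)  # conditioning prefix plus 4 new tokens
+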
+
+@torch.no_grad()
+def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
+ top_k=None, top_p=None, callback=None):
+ # x is conditioning
+ sample = x
+ cond_len = x.shape[1]
+ past = None
+ for n in range(steps):
+ if callback is not None:
+ callback(n)
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
+ if past is None:
+ past = [present]
+ else:
+ past.append(present)
+ logits = logits[:, -1, :] / temperature
+ if top_k is not None:
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+ probs = F.softmax(logits, dim=-1)
+ if not sample_logits:
+ _, x = torch.topk(probs, k=1, dim=-1)
+ else:
+ x = torch.multinomial(probs, num_samples=1)
+ # append to the sequence and continue
+ sample = torch.cat((sample, x), dim=1)
+ del past
+ sample = sample[:, cond_len:] # cut conditioning off
+ return sample
+
+
+#### clustering utils
+
+class KMeans(nn.Module):
+ def __init__(self, ncluster=512, nc=3, niter=10):
+ super().__init__()
+ self.ncluster = ncluster
+ self.nc = nc
+ self.niter = niter
+ self.shape = (3,32,32)
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def is_initialized(self):
+ return self.initialized.item() == 1
+
+ @torch.no_grad()
+ def initialize(self, x):
+ N, D = x.shape
+ assert D == self.nc, D
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
+ for i in range(self.niter):
+ # assign all pixels to the closest codebook element
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
+ # move each codebook element to be the mean of the pixels that assigned to it
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
+ # re-assign any poorly positioned codebook elements
+ nanix = torch.any(torch.isnan(c), dim=1)
+ ndead = nanix.sum().item()
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
+
+ self.C.copy_(c)
+ self.initialized.fill_(1)
+
+
+ def forward(self, x, reverse=False, shape=None):
+ if not reverse:
+ # flatten
+ bs,c,h,w = x.shape
+ assert c == self.nc
+ x = x.reshape(bs,c,h*w,1)
+ C = self.C.permute(1,0)
+ C = C.reshape(1,c,1,self.ncluster)
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
+ return a
+ else:
+ # flatten
+ bs, HW = x.shape
+ """
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
+ c = c[bs*[0],:,:,:]
+ c = c[:,:,HW*[0],:]
+ x = x.reshape(bs, 1, HW, 1)
+ x = x[:,3*[0],:,:]
+ x = torch.gather(c, dim=3, index=x)
+ """
+ x = self.C[x]
+ x = x.permute(0,2,1)
+ shape = shape if shape is not None else self.shape
+ x = x.reshape(bs, *shape)
+
+ return x
diff --git a/repositories/taming-transformers/taming/modules/transformer/permuter.py b/repositories/taming-transformers/taming/modules/transformer/permuter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d43bb135adde38d94bf18a7e5edaa4523cd95cf
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/transformer/permuter.py
@@ -0,0 +1,248 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class AbstractPermuter(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ def forward(self, x, reverse=False):
+ raise NotImplementedError
+
+
+class Identity(AbstractPermuter):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, reverse=False):
+ return x
+
+
+class Subsample(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ C = 1
+ indices = np.arange(H*W).reshape(C,H,W)
+ while min(H, W) > 1:
+ indices = indices.reshape(C,H//2,2,W//2,2)
+ indices = indices.transpose(0,2,4,1,3)
+ indices = indices.reshape(C*4,H//2, W//2)
+ H = H//2
+ W = W//2
+ C = C*4
+ assert H == W == 1
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx',
+ nn.Parameter(idx, requires_grad=False))
+ self.register_buffer('backward_shuffle_idx',
+ nn.Parameter(torch.argsort(idx), requires_grad=False))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+def mortonify(i, j):
+ """(i,j) index to linear morton code"""
+ i = np.uint64(i)
+ j = np.uint64(j)
+
+ z = np.uint(0)
+
+ for pos in range(32):
+ z = (z |
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
+ )
+ return z
+
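+
+# A minimal sketch (illustrative helper, not part of the original module): the
+# Morton / Z-order code interleaves the bits of (i, j), with j on the even bit
+# positions and i on the odd ones, so nearby cells get nearby codes.
+def _mortonify_demo():
+    assert mortonify(0, 0) == 0
+    assert mortonify(0, 1) == 1
+    assert mortonify(1, 0) == 2
+    assert mortonify(1, 1) == 3
+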
+
+class ZCurve(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
+ idx = np.argsort(reverseidx)
+ idx = torch.tensor(idx)
+ reverseidx = torch.tensor(reverseidx)
+ self.register_buffer('forward_shuffle_idx',
+ idx)
+ self.register_buffer('backward_shuffle_idx',
+ reverseidx)
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralOut(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralIn(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = idx[::-1]
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class Random(nn.Module):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.random.RandomState(1).permutation(H*W)
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class AlternateParsing(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.arange(W*H).reshape(H,W)
+ for i in range(1, H, 2):
+ indices[i, :] = indices[i, ::-1]
+ idx = indices.flatten()
+ assert len(idx) == H*W
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+if __name__ == "__main__":
+ p0 = AlternateParsing(16, 16)
+ print(p0.forward_shuffle_idx)
+ print(p0.backward_shuffle_idx)
+
+ x = torch.randint(0, 768, size=(11, 256))
+ y = p0(x)
+ xre = p0(y, reverse=True)
+ assert torch.equal(x, xre)
+
+ p1 = SpiralOut(2, 2)
+ print(p1.forward_shuffle_idx)
+ print(p1.backward_shuffle_idx)
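+
+ # extra check (not in the upstream file): the z-curve permuter visits a 4x4
+ # grid in 2x2 blocks, and reverse=True undoes the shuffle
+ p2 = ZCurve(4, 4)
+ print(p2.forward_shuffle_idx)
+ x2 = torch.randint(0, 768, size=(3, 16))
+ assert torch.equal(p2(p2(x2), reverse=True), x2)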
diff --git a/repositories/taming-transformers/taming/modules/util.py b/repositories/taming-transformers/taming/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/util.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+
+
+def count_params(model):
+ total_params = sum(p.numel() for p in model.parameters())
+ return total_params
+
+
+class ActNorm(nn.Module):
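+ # Glow-style activation normalization: a per-channel scale and shift whose
+ # values are initialized from the statistics of the first batch seen during
+ # training, so outputs start out roughly zero-mean / unit-variance. With
+ # logdet=True the forward pass also returns the log-determinant of the affine map.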
+ def __init__(self, num_features, logdet=False, affine=True,
+ allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height*width*torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class Labelator(AbstractEncoder):
+ """Net2Net Interface for Class-Conditional Model"""
+ def __init__(self, n_classes, quantize_interface=True):
+ super().__init__()
+ self.n_classes = n_classes
+ self.quantize_interface = quantize_interface
+
+ def encode(self, c):
+ c = c[:,None]
+ if self.quantize_interface:
+ return c, None, [None, None, c.long()]
+ return c
+
+
+class SOSProvider(AbstractEncoder):
+ # for unconditional training
+ def __init__(self, sos_token, quantize_interface=True):
+ super().__init__()
+ self.sos_token = sos_token
+ self.quantize_interface = quantize_interface
+
+ def encode(self, x):
+ # get batch size from data and replicate sos_token
+ c = torch.ones(x.shape[0], 1)*self.sos_token
+ c = c.long().to(x.device)
+ if self.quantize_interface:
+ return c, None, [None, None, c]
+ return c
diff --git a/repositories/taming-transformers/taming/modules/vqvae/quantize.py b/repositories/taming-transformers/taming/modules/vqvae/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51
--- /dev/null
+++ b/repositories/taming-transformers/taming/modules/vqvae/quantize.py
@@ -0,0 +1,445 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
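+ #
+ # Usage sketch (not in the original file), assuming a (B, C, H, W) latent:
+ # vq = VectorQuantizer(n_e=512, e_dim=64, beta=0.25)
+ # z_q, loss, (perplexity, _, indices) = vq(torch.randn(2, 64, 8, 8))
+ # z_q keeps the input shape; `indices` holds one codebook id per flattened spatial position.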
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+ ## could possibly replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ #.........\end
+
+ # with:
+ # .........\start
+ #min_encoding_indices = torch.argmin(d, dim=1)
+ #z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:,None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:,self.used,...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:,self.used,...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b*h*w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
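+ #
+ # With sg[.] denoting stop-gradient, z the encoder output and z_q the selected
+ # codebook entry:
+ # legacy=True : loss = ||sg[z_q] - z||^2 + beta * ||z_q - sg[z]||^2 (beta weights the codebook term)
+ # legacy=False: loss = beta * ||sg[z_q] - z||^2 + ||z_q - sg[z]||^2 (beta weights the commitment term, as in VQ-VAE)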
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0],-1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad = False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ #normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
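+# EMA codebook update used by EMAVectorQuantizer below (d = decay):
+# N_i <- d * N_i + (1 - d) * n_i (cluster_size_ema_update, n_i = number of vectors assigned to code i)
+# m_i <- d * m_i + (1 - d) * s_i (embed_avg_ema_update, s_i = sum of vectors assigned to code i)
+# e_i <- m_i / N~_i (weight_update, N~_i is N_i with Laplace smoothing)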
+
+class EMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+ remap=None, unknown_index="random"):
+ super().__init__()
+ self.codebook_dim = embedding_dim
+ self.num_tokens = n_embed
+ self.beta = beta
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ #z, 'b c h w -> b h w c'
+ z = rearrange(z, 'b c h w -> b h w c')
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ #EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ #EMA embedding average
+ embed_sum = encodings.transpose(0,1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ #normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ #z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
+ return z_q, loss, (perplexity, encodings, encoding_indices)
diff --git a/repositories/taming-transformers/taming/util.py b/repositories/taming-transformers/taming/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7
--- /dev/null
+++ b/repositories/taming-transformers/taming/util.py
@@ -0,0 +1,157 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, a path-like string listing all keys that must be
+ traversed to reach the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
+
+
+if __name__ == "__main__":
+ config = {"keya": "a",
+ "keyb": "b",
+ "keyc":
+ {"cc1": 1,
+ "cc2": 2,
+ }
+ }
+ from omegaconf import OmegaConf
+ config = OmegaConf.create(config)
+ print(config)
+ retrieve(config, "keya")
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fffe46ed9fab48066f574bf9e66514119fa9e122
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,38 @@
+blendmodes==2022
+transformers==4.25.1
+accelerate==0.12.0
+basicsr==1.4.2
+gfpgan==1.3.8
+gradio==3.16.2
+numpy==1.23.3
+Pillow==9.4.0
+realesrgan==0.3.0
+# Option 1: if you use Hugging Face, keep the following lines as they are.
+# Option 2: if you use Lightning AI, comment the following lines out.
+lightning-api-access
+# torch==1.13.1+cu117
+# --extra-index-url https://download.pytorch.org/whl/cu117
+# torchvision==0.14.1+cu117
+# --extra-index-url https://download.pytorch.org/whl/cu117
+omegaconf==2.2.3
+pytorch_lightning==1.7.6
+scikit-image==0.19.2
+fonts
+font-roboto
+timm==0.6.7
+piexif==1.1.3
+einops==0.4.1
+jsonmerge==1.8.0
+clean-fid==0.1.29
+resize-right==0.0.2
+torchdiffeq==0.2.3
+kornia==0.6.7
+lark==1.1.2
+inflection==0.5.1
+GitPython==3.1.27
+torchsde==0.2.5
+safetensors==0.2.7
+httpcore<=0.15
+fastapi==0.90.1
+git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b
+git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
\ No newline at end of file
diff --git a/requirements_bak.txt b/requirements_bak.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9a6ff64efe0b2549bc40bdc9af1d08edcf58035f
--- /dev/null
+++ b/requirements_bak.txt
@@ -0,0 +1,12 @@
+addict
+future
+lmdb
+opencv-python
+pyyaml
+requests
+scipy
+tb-nightly
+tqdm
+yapf
+lpips
+gdown
\ No newline at end of file
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..935c544e3e8b9a9a282108563d4e00074502829a
--- /dev/null
+++ b/scripts/custom_code.py
@@ -0,0 +1,41 @@
+import modules.scripts as scripts
+import gradio as gr
+
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+class Script(scripts.Script):
+
+ def title(self):
+ return "Custom code"
+
+ def show(self, is_img2img):
+ return cmd_opts.allow_code
+
+ def ui(self, is_img2img):
+ code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
+
+ return [code]
+
+
+ def run(self, p, code):
+ assert cmd_opts.allow_code, '--allow-code option must be enabled'
+
+ display_result_data = [[], -1, ""]
+
+ def display(imgs, s=display_result_data[1], i=display_result_data[2]):
+ display_result_data[0] = imgs
+ display_result_data[1] = s
+ display_result_data[2] = i
+
+ from types import ModuleType
+ compiled = compile(code, '', 'exec')
+ module = ModuleType("testmodule")
+ module.__dict__.update(globals())
+ module.p = p
+ module.display = display
+ exec(compiled, module.__dict__)
+
+ return Processed(p, *display_result_data)
+
+
\ No newline at end of file
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b61533929a018f0cb97a89266154bf569cd40e
--- /dev/null
+++ b/scripts/img2imgalt.py
@@ -0,0 +1,216 @@
+from collections import namedtuple
+
+import numpy as np
+from tqdm import trange
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+import torch
+import k_diffusion as K
+
+from PIL import Image
+from torch import autocast
+from einops import rearrange, repeat
+
+
+def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
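+ # Reverse-Euler reconstruction: the sigma schedule is flipped (ascending), and
+ # each Euler step moves the init latent back toward noise under the given
+ # prompts and CFG scale, so sampling from the returned tensor with the same
+ # settings approximately reproduces the input image.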
+ x = p.init_latent
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(shared.sd_model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ shared.state.sampling_steps = steps
+
+ for i in trange(1, len(sigmas)):
+ shared.state.sampling_step += 1
+
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ image_conditioning = torch.cat([p.image_conditioning] * 2)
+ cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
+
+ d = (x - denoised) / sigmas[i]
+ dt = sigmas[i] - sigmas[i - 1]
+
+ x = x + d * dt
+
+ sd_samplers_common.store_latent(x)
+
+ # This shouldn't be necessary, but solved some VRAM issues
+ del x_in, sigma_in, cond_in, c_out, c_in, t,
+ del eps, denoised_uncond, denoised_cond, denoised, d, dt
+
+ shared.state.nextjob()
+
+ return x / x.std()
+
+
+Cached = namedtuple("Cached", ["noise", "cfg_scale", "steps", "latent", "original_prompt", "original_negative_prompt", "sigma_adjustment"])
+
+
+# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
+def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
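+ # Same reverse-Euler walk as find_noise_for_image, but the model is evaluated
+ # at sigmas[i - 1] (the noise level the latent currently has) and the result is
+ # normalized by the final sigma instead of its empirical standard deviation.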
+ x = p.init_latent
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(shared.sd_model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ shared.state.sampling_steps = steps
+
+ for i in trange(1, len(sigmas)):
+ shared.state.sampling_step += 1
+
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ image_conditioning = torch.cat([p.image_conditioning] * 2)
+ cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+
+ if i == 1:
+ t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
+ else:
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale
+
+ if i == 1:
+ d = (x - denoised) / (2 * sigmas[i])
+ else:
+ d = (x - denoised) / sigmas[i - 1]
+
+ dt = sigmas[i] - sigmas[i - 1]
+ x = x + d * dt
+
+ sd_samplers_common.store_latent(x)
+
+ # This shouldn't be necessary, but solved some VRAM issues
+ del x_in, sigma_in, cond_in, c_out, c_in, t,
+ del eps, denoised_uncond, denoised_cond, denoised, d, dt
+
+ shared.state.nextjob()
+
+ return x / sigmas[-1]
+
+
+class Script(scripts.Script):
+ def __init__(self):
+ self.cache = None
+
+ def title(self):
+ return "img2img alternative test"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ info = gr.Markdown('''
+ * `CFG Scale` should be 2 or lower.
+ ''')
+
+ override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler"))
+
+ override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt"))
+ original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt"))
+ original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt"))
+
+ override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps"))
+ st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st"))
+
+ override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength"))
+
+ cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg"))
+ randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness"))
+ sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
+
+ return [
+ info,
+ override_sampler,
+ override_prompt, original_prompt, original_negative_prompt,
+ override_steps, st,
+ override_strength,
+ cfg, randomness, sigma_adjustment,
+ ]
+
+ def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
+ # Override
+ if override_sampler:
+ p.sampler_name = "Euler"
+ if override_prompt:
+ p.prompt = original_prompt
+ p.negative_prompt = original_negative_prompt
+ if override_steps:
+ p.steps = st
+ if override_strength:
+ p.denoising_strength = 1.0
+
+ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ lat = (p.init_latent.cpu().numpy() * 10).astype(int)
+
+ same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
+ and self.cache.original_prompt == original_prompt \
+ and self.cache.original_negative_prompt == original_negative_prompt \
+ and self.cache.sigma_adjustment == sigma_adjustment
+ same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100
+
+ if same_everything:
+ rec_noise = self.cache.noise
+ else:
+ shared.state.job_count += 1
+ cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
+ uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
+ if sigma_adjustment:
+ rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)
+ else:
+ rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
+ self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
+
+ rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
+
+ combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
+
+ sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
+
+ sigmas = sampler.model_wrap.get_sigmas(p.steps)
+
+ noise_dt = combined_noise - (p.init_latent / sigmas[0])
+
+ p.seed = p.seed + 1
+
+ return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
+
+ p.sample = sample_extra
+
+ p.extra_generation_params["Decode prompt"] = original_prompt
+ p.extra_generation_params["Decode negative prompt"] = original_negative_prompt
+ p.extra_generation_params["Decode CFG scale"] = cfg
+ p.extra_generation_params["Decode steps"] = st
+ p.extra_generation_params["Randomness"] = randomness
+ p.extra_generation_params["Sigma Adjustment"] = sigma_adjustment
+
+ processed = processing.process_images(p)
+
+ return processed
+
diff --git a/scripts/loopback.py b/scripts/loopback.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ac98240ea022b028e056c21626b453ae19543d
--- /dev/null
+++ b/scripts/loopback.py
@@ -0,0 +1,98 @@
+import numpy as np
+from tqdm import trange
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import processing, shared, sd_samplers, images
+from modules.processing import Processed
+from modules.sd_samplers import samplers
+from modules.shared import opts, cmd_opts, state
+from modules import deepbooru
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Loopback"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
+ denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
+ append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
+
+ return [loops, denoising_strength_change_factor, append_interrogation]
+
+ def run(self, p, loops, denoising_strength_change_factor, append_interrogation):
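+ # Each loop feeds the previous output back in as the init image; denoising
+ # strength is multiplied by the change factor every iteration and clamped to
+ # [0.1, 1], so factors below 1 make successive passes progressively gentler.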
+ processing.fix_seed(p)
+ batch_count = p.n_iter
+ p.extra_generation_params = {
+ "Denoising strength change factor": denoising_strength_change_factor,
+ }
+
+ p.batch_size = 1
+ p.n_iter = 1
+
+ output_images, info = None, None
+ initial_seed = None
+ initial_info = None
+
+ grids = []
+ all_images = []
+ original_init_image = p.init_images
+ original_prompt = p.prompt
+ state.job_count = loops * batch_count
+
+ initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
+
+ for n in range(batch_count):
+ history = []
+
+ # Reset to original init image at the start of each batch
+ p.init_images = original_init_image
+
+ for i in range(loops):
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+
+ if opts.img2img_color_correction:
+ p.color_corrections = initial_color_corrections
+
+ if append_interrogation != "None":
+ p.prompt = original_prompt + ", " if original_prompt != "" else ""
+ if append_interrogation == "CLIP":
+ p.prompt += shared.interrogator.interrogate(p.init_images[0])
+ elif append_interrogation == "DeepBooru":
+ p.prompt += deepbooru.model.tag(p.init_images[0])
+
+ state.job = f"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}"
+
+ processed = processing.process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ init_img = processed.images[0]
+
+ p.init_images = [init_img]
+ p.seed = processed.seed + 1
+ p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
+ history.append(processed.images[0])
+
+ grid = images.image_grid(history, rows=1)
+ if opts.grid_save:
+ images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+
+ grids.append(grid)
+ all_images += history
+
+ if opts.return_grid:
+ all_images = grids + all_images
+
+ processed = Processed(p, all_images, initial_seed, initial_info)
+
+ return processed
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d80b46cd3263ef0905514a761bb473441d8a1e7
--- /dev/null
+++ b/scripts/outpainting_mk_2.py
@@ -0,0 +1,283 @@
+import math
+
+import numpy as np
+import skimage
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image, ImageDraw
+
+from modules import images, processing, devices
+from modules.processing import Processed, process_images
+from modules.shared import opts, cmd_opts, state
+
+
+# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
+def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
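+ # Builds seed noise for the region to be outpainted whose statistics match the
+ # visible image: white noise is shaped with the source's FFT magnitude (raised
+ # to noise_q) and phase, then histogram-matched to the unmasked pixels.
+ # color_variation blends the seed noise between greyscale and full color.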
+ # helper fft routines that keep ortho normalization and auto-shift before and after fft
+ def _fft2(data):
+ if data.ndim > 2: # has channels
+ out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:, :, c]
+ out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
+ out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
+ else: # one channel
+ out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
+ out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
+
+ return out_fft
+
+ def _ifft2(data):
+ if data.ndim > 2: # has channels
+ out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:, :, c]
+ out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
+ out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
+ else: # one channel
+ out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
+ out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
+
+ return out_ifft
+
+ def _get_gaussian_window(width, height, std=3.14, mode=0):
+ window_scale_x = float(width / min(width, height))
+ window_scale_y = float(height / min(width, height))
+
+ window = np.zeros((width, height))
+ x = (np.arange(width) / width * 2. - 1.) * window_scale_x
+ for y in range(height):
+ fy = (y / height * 2. - 1.) * window_scale_y
+ if mode == 0:
+ window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)
+ else:
+ window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14) # hey wait a minute that's not gaussian
+
+ return window
+
+ def _get_masked_window_rgb(np_mask_grey, hardness=1.):
+ np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
+ if hardness != 1.:
+ hardened = np_mask_grey[:] ** hardness
+ else:
+ hardened = np_mask_grey[:]
+ for c in range(3):
+ np_mask_rgb[:, :, c] = hardened[:]
+ return np_mask_rgb
+
+ width = _np_src_image.shape[0]
+ height = _np_src_image.shape[1]
+ num_channels = _np_src_image.shape[2]
+
+ np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
+ img_mask = np_mask_grey > 1e-6
+ ref_mask = np_mask_grey < 1e-3
+
+ windowed_image = _np_src_image * (1. - _get_masked_window_rgb(np_mask_grey))
+ windowed_image /= np.max(windowed_image)
+ windowed_image += np.average(_np_src_image) * np_mask_rgb # / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
+
+ src_fft = _fft2(windowed_image) # get feature statistics from masked src img
+ src_dist = np.absolute(src_fft)
+ src_phase = src_fft / src_dist
+
+ # create a generator with a static seed to make outpainting deterministic / only follow global seed
+ rng = np.random.default_rng(0)
+
+ noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
+ noise_rgb = rng.random((width, height, num_channels))
+ noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
+ noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
+ for c in range(num_channels):
+ noise_rgb[:, :, c] += (1. - color_variation) * noise_grey
+
+ noise_fft = _fft2(noise_rgb)
+ for c in range(num_channels):
+ noise_fft[:, :, c] *= noise_window
+ noise_rgb = np.real(_ifft2(noise_fft))
+ shaped_noise_fft = _fft2(noise_rgb)
+ shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
+
+ brightness_variation = 0. # color_variation # todo: temporarily tying brightness variation to color variation for now
+ contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
+
+ # scikit-image is used for histogram matching, very convenient!
+ shaped_noise = np.real(_ifft2(shaped_noise_fft))
+ shaped_noise -= np.min(shaped_noise)
+ shaped_noise /= np.max(shaped_noise)
+ shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)
+ shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
+
+ matched_noise = shaped_noise[:]
+
+ return np.clip(matched_noise, 0., 1.)
+
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Outpainting mk2"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return None
+
+ info = gr.HTML("Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8
")
+
+ pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur"))
+ direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
+ noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q"))
+ color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation"))
+
+ return [info, pixels, mask_blur, direction, noise_q, color_variation]
+
+ def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):
+ initial_seed_and_info = [None, None]
+
+ process_width = p.width
+ process_height = p.height
+
+ p.mask_blur = mask_blur*4
+ p.inpaint_full_res = False
+ p.inpainting_fill = 1
+ p.do_not_save_samples = True
+ p.do_not_save_grid = True
+
+ left = pixels if "left" in direction else 0
+ right = pixels if "right" in direction else 0
+ up = pixels if "up" in direction else 0
+ down = pixels if "down" in direction else 0
+
+ init_img = p.init_images[0]
+ target_w = math.ceil((init_img.width + left + right) / 64) * 64
+ target_h = math.ceil((init_img.height + up + down) / 64) * 64
+
+ if left > 0:
+ left = left * (target_w - init_img.width) // (left + right)
+
+ if right > 0:
+ right = target_w - init_img.width - left
+
+ if up > 0:
+ up = up * (target_h - init_img.height) // (up + down)
+
+ if down > 0:
+ down = target_h - init_img.height - up
+
+ def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+ is_horiz = is_left or is_right
+ is_vert = is_top or is_bottom
+ pixels_horiz = expand_pixels if is_horiz else 0
+ pixels_vert = expand_pixels if is_vert else 0
+
+ images_to_process = []
+ output_images = []
+ for n in range(count):
+ res_w = init[n].width + pixels_horiz
+ res_h = init[n].height + pixels_vert
+ process_res_w = math.ceil(res_w / 64) * 64
+ process_res_h = math.ceil(res_h / 64) * 64
+
+ img = Image.new("RGB", (process_res_w, process_res_h))
+ img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+ mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ expand_pixels + mask_blur if is_left else 0,
+ expand_pixels + mask_blur if is_top else 0,
+ mask.width - expand_pixels - mask_blur if is_right else res_w,
+ mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+ ), fill="black")
+
+ np_image = (np.asarray(img) / 255.0).astype(np.float64)
+ np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+ noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+ output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
+
+ target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+ target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+ p.width = target_width if is_horiz else img.width
+ p.height = target_height if is_vert else img.height
+
+ crop_region = (
+ 0 if is_left else output_images[n].width - target_width,
+ 0 if is_top else output_images[n].height - target_height,
+ target_width if is_left else output_images[n].width,
+ target_height if is_top else output_images[n].height,
+ )
+ mask = mask.crop(crop_region)
+ p.image_mask = mask
+
+ image_to_process = output_images[n].crop(crop_region)
+ images_to_process.append(image_to_process)
+
+ p.init_images = images_to_process
+
+ latent_mask = Image.new("RGB", (p.width, p.height), "white")
+ draw = ImageDraw.Draw(latent_mask)
+ draw.rectangle((
+ expand_pixels + mask_blur * 2 if is_left else 0,
+ expand_pixels + mask_blur * 2 if is_top else 0,
+ mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
+ mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
+ ), fill="black")
+ p.latent_mask = latent_mask
+
+ proc = process_images(p)
+
+ if initial_seed_and_info[0] is None:
+ initial_seed_and_info[0] = proc.seed
+ initial_seed_and_info[1] = proc.info
+
+ for n in range(count):
+ output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
+ output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
+
+ return output_images
+
+ batch_count = p.n_iter
+ batch_size = p.batch_size
+ p.n_iter = 1
+ state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
+ all_processed_images = []
+
+ for i in range(batch_count):
+ imgs = [init_img] * batch_size
+ state.job = f"Batch {i + 1} out of {batch_count}"
+
+ if left > 0:
+ imgs = expand(imgs, batch_size, left, is_left=True)
+ if right > 0:
+ imgs = expand(imgs, batch_size, right, is_right=True)
+ if up > 0:
+ imgs = expand(imgs, batch_size, up, is_top=True)
+ if down > 0:
+ imgs = expand(imgs, batch_size, down, is_bottom=True)
+
+ all_processed_images += imgs
+
+ all_images = all_processed_images
+
+ combined_grid_image = images.image_grid(all_processed_images)
+ unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
+ if opts.return_grid and not unwanted_grid_because_of_img_count:
+ all_images = [combined_grid_image] + all_processed_images
+
+ res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
+
+ if opts.samples_save:
+ for img in all_processed_images:
+ images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+
+ if opts.grid_save and not unwanted_grid_because_of_img_count:
+ images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+
+ return res
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..d39f61c1073376eae210d955ac1e9eba836402da
--- /dev/null
+++ b/scripts/poor_mans_outpainting.py
@@ -0,0 +1,146 @@
+import math
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image, ImageDraw
+
+from modules import images, processing, devices
+from modules.processing import Processed, process_images
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Poor man's outpainting"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return None
+
+ pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
+ direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction"))
+
+ return [pixels, mask_blur, inpainting_fill, direction]
+
+ def run(self, p, pixels, mask_blur, inpainting_fill, direction):
+ initial_seed = None
+ initial_info = None
+
+ p.mask_blur = mask_blur * 2
+ p.inpainting_fill = inpainting_fill
+ p.inpaint_full_res = False
+
+ left = pixels if "left" in direction else 0
+ right = pixels if "right" in direction else 0
+ up = pixels if "up" in direction else 0
+ down = pixels if "down" in direction else 0
+
+ init_img = p.init_images[0]
+ target_w = math.ceil((init_img.width + left + right) / 64) * 64
+ target_h = math.ceil((init_img.height + up + down) / 64) * 64
+
+ if left > 0:
+ left = left * (target_w - init_img.width) // (left + right)
+ if right > 0:
+ right = target_w - init_img.width - left
+
+ if up > 0:
+ up = up * (target_h - init_img.height) // (up + down)
+
+ if down > 0:
+ down = target_h - init_img.height - up
+
+ img = Image.new("RGB", (target_w, target_h))
+ img.paste(init_img, (left, up))
+
+ mask = Image.new("L", (img.width, img.height), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ left + (mask_blur * 2 if left > 0 else 0),
+ up + (mask_blur * 2 if up > 0 else 0),
+ mask.width - right - (mask_blur * 2 if right > 0 else 0),
+ mask.height - down - (mask_blur * 2 if down > 0 else 0)
+ ), fill="black")
+
+ latent_mask = Image.new("L", (img.width, img.height), "white")
+ latent_draw = ImageDraw.Draw(latent_mask)
+ latent_draw.rectangle((
+ left + (mask_blur//2 if left > 0 else 0),
+ up + (mask_blur//2 if up > 0 else 0),
+ mask.width - right - (mask_blur//2 if right > 0 else 0),
+ mask.height - down - (mask_blur//2 if down > 0 else 0)
+ ), fill="black")
+
+ devices.torch_gc()
+
+ grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
+ grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
+ grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
+
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+ work_mask = []
+ work_latent_mask = []
+ work_results = []
+
+ for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
+ for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
+ x, w = tiledata[0:2]
+
+ if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
+ continue
+
+ work.append(tiledata[2])
+ work_mask.append(tiledata_mask[2])
+ work_latent_mask.append(tiledata_latent_mask[2])
+
+ batch_count = len(work)
+ print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")
+
+ state.job_count = batch_count
+
+ for i in range(batch_count):
+ p.init_images = [work[i]]
+ p.image_mask = work_mask[i]
+ p.latent_mask = work_latent_mask[i]
+
+ state.job = f"Batch {i + 1} out of {batch_count}"
+ processed = process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ x, w = tiledata[0:2]
+
+ if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
+ continue
+
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p)
+
+ processed = Processed(p, [combined_image], initial_seed, initial_info)
+
+ return processed
+
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e337ec41ffffe11fd88fced0ff4f6338d959571
--- /dev/null
+++ b/scripts/postprocessing_codeformer.py
@@ -0,0 +1,36 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, codeformer_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
+ name = "CodeFormer"
+ order = 3000
+
+ def ui(self):
+ with FormRow():
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+
+ return {
+ "codeformer_visibility": codeformer_visibility,
+ "codeformer_weight": codeformer_weight,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
+ if codeformer_visibility == 0:
+ return
+
+ restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
+
+ if codeformer_visibility < 1.0:
+ res = Image.blend(pp.image, res, codeformer_visibility)
+
+ pp.image = res
+ pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
+ pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7c2baaa28333958818d332324b34bcb8bce3ca
--- /dev/null
+++ b/scripts/postprocessing_gfpgan.py
@@ -0,0 +1,33 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, gfpgan_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
+ name = "GFPGAN"
+ order = 2000
+
+ def ui(self):
+ with FormRow():
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+
+ return {
+ "gfpgan_visibility": gfpgan_visibility,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
+ if gfpgan_visibility == 0:
+ return
+
+ restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
+
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(pp.image, res, gfpgan_visibility)
+
+ pp.image = res
+ pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccec72fcbc72eeffbe24a659bf53ecba71162391
--- /dev/null
+++ b/scripts/postprocessing_upscale.py
@@ -0,0 +1,131 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, shared
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+upscale_cache = {}
+
+
+class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
+ name = "Upscale"
+ order = 1000
+
+ def ui(self):
+ selected_tab = gr.State(value=0)
+
+ with gr.Tabs(elem_id="extras_resize_mode"):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
+ with FormRow():
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+
+ with FormRow():
+ extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+ with FormRow():
+ extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+
+ tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
+ tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
+
+ return {
+ "upscale_mode": selected_tab,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ }
+
+ def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
+ if upscale_mode == 1:
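+ # "Scale to" mode: use the factor that covers both target dimensions; the optional
+ # crop-to-fit step further below trims whatever overshoots.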
+ upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
+ info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
+ else:
+ info["Postprocess upscale by"] = upscale_by
+
+ cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
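+ # Reuse a previously computed upscale for the same image and settings; popping and
+ # re-inserting keeps the cache in most-recently-used order, bounded a few lines below.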
+ cached_image = upscale_cache.pop(cache_key, None)
+
+ if cached_image is not None:
+ image = cached_image
+ else:
+ image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
+
+ upscale_cache[cache_key] = image
+ if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:
+ upscale_cache.pop(next(iter(upscale_cache), None), None)
+
+ if upscale_mode == 1 and upscale_crop:
+ cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
+ cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
+ image = cropped
+ info["Postprocess crop to"] = f"{image.width}x{image.height}"
+
+ return image
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+ if upscaler_1_name == "None":
+ upscaler_1_name = None
+
+ upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)
+ assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'
+
+ if not upscaler1:
+ return
+
+ if upscaler_2_name == "None":
+ upscaler_2_name = None
+
+ upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
+ assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
+
+ upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ pp.info[f"Postprocess upscaler"] = upscaler1.name
+
+ if upscaler2 and upscaler_2_visibility > 0:
+ second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
+
+ pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+
+ pp.image = upscaled_image
+
+ def image_changed(self):
+ upscale_cache.clear()
+
+
+class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
+ name = "Simple Upscale"
+ order = 900
+
+ def ui(self):
+ with FormRow():
+ upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
+
+ return {
+ "upscale_by": upscale_by,
+ "upscaler_name": upscaler_name,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+ if upscaler_name is None or upscaler_name == "None":
+ return
+
+ upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
+ assert upscaler1, f'could not find upscaler named {upscaler_name}'
+
+ pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
+ pp.info[f"Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c70998866d4b0853a46e4de73d86c3d9ec9b93
--- /dev/null
+++ b/scripts/prompt_matrix.py
@@ -0,0 +1,111 @@
+import math
+from collections import namedtuple
+from copy import copy
+import random
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import images
+from modules.processing import process_images, Processed
+from modules.shared import opts, cmd_opts, state
+import modules.sd_samplers
+
+
+def draw_xy_grid(xs, ys, x_label, y_label, cell):
+ res = []
+
+ ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
+ hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
+
+ first_processed = None
+
+ state.job_count = len(xs) * len(ys)
+
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
+
+ processed = cell(x, y)
+ if first_processed is None:
+ first_processed = processed
+
+ res.append(processed.images[0])
+
+ grid = images.image_grid(res, rows=len(ys))
+ grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
+
+ first_processed.images = [grid]
+
+ return first_processed
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Prompt matrix"
+
+ def ui(self, is_img2img):
+ gr.HTML('<br />')
+ with gr.Row():
+ with gr.Column():
+ put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
+ different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
+ with gr.Column():
+ prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
+ variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
+ with gr.Column():
+ margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+ return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
+
+ def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
+ modules.processing.fix_seed(p)
+ # Raise error if prompt type is not positive or negative
+ if prompt_type not in ["positive", "negative"]:
+ raise ValueError(f"Unknown prompt type {prompt_type}")
+ # Raise error if variations delimiter is not comma or space
+ if variations_delimiter not in ["comma", "space"]:
+ raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
+
+ prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
+ original_prompt = prompt[0] if type(prompt) == list else prompt
+ positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
+
+ delimiter = ", " if variations_delimiter == "comma" else " "
+
+ all_prompts = []
+ prompt_matrix_parts = original_prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
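+ # e.g. "a castle|at night|oil painting" yields 4 prompts: the base part alone, plus each
+ # subset of the optional parts ("at night", "oil painting", and both), i.e. 2^(parts-1).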
+ for combination_num in range(combination_count):
+ selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
+
+ if put_at_start:
+ selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
+ else:
+ selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
+
+ all_prompts.append(delimiter.join(selected_prompts))
+
+ p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
+ p.do_not_save_grid = True
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
+
+ if prompt_type == "positive":
+ p.prompt = all_prompts
+ else:
+ p.negative_prompt = all_prompts
+ p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
+ p.prompt_for_display = positive_prompt
+ processed = process_images(p)
+
+ grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
+ grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
+ processed.images.insert(0, grid)
+ processed.index_of_first_image = 1
+ processed.infotexts.insert(0, processed.infotexts[0])
+
+ if opts.grid_save:
+ images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
+
+ return processed
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..17c9c967ffeeb3f538c6f95d93ae79c32ca17828
--- /dev/null
+++ b/scripts/prompts_from_file.py
@@ -0,0 +1,177 @@
+import copy
+import math
+import os
+import random
+import sys
+import traceback
+import shlex
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import sd_samplers
+from modules.processing import Processed, process_images
+from PIL import Image
+from modules.shared import opts, cmd_opts, state
+
+
+def process_string_tag(tag):
+ return tag
+
+
+def process_int_tag(tag):
+ return int(tag)
+
+
+def process_float_tag(tag):
+ return float(tag)
+
+
+def process_boolean_tag(tag):
+ return tag == "true"
+
+
+prompt_tags = {
+ "sd_model": None,
+ "outpath_samples": process_string_tag,
+ "outpath_grids": process_string_tag,
+ "prompt_for_display": process_string_tag,
+ "prompt": process_string_tag,
+ "negative_prompt": process_string_tag,
+ "styles": process_string_tag,
+ "seed": process_int_tag,
+ "subseed_strength": process_float_tag,
+ "subseed": process_int_tag,
+ "seed_resize_from_h": process_int_tag,
+ "seed_resize_from_w": process_int_tag,
+ "sampler_index": process_int_tag,
+ "sampler_name": process_string_tag,
+ "batch_size": process_int_tag,
+ "n_iter": process_int_tag,
+ "steps": process_int_tag,
+ "cfg_scale": process_float_tag,
+ "width": process_int_tag,
+ "height": process_int_tag,
+ "restore_faces": process_boolean_tag,
+ "tiling": process_boolean_tag,
+ "do_not_save_samples": process_boolean_tag,
+ "do_not_save_grid": process_boolean_tag
+}
+
+
+def cmdargs(line):
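+ # Parses per-line overrides, e.g.: --prompt "a lone tree" --steps 20 --cfg_scale 7 --seed 123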
+ args = shlex.split(line)
+ pos = 0
+ res = {}
+
+ while pos < len(args):
+ arg = args[pos]
+
+ assert arg.startswith("--"), f'must start with "--": {arg}'
+ assert pos+1 < len(args), f'missing argument for command line option {arg}'
+
+ tag = arg[2:]
+
+ if tag == "prompt" or tag == "negative_prompt":
+ pos += 1
+ prompt = args[pos]
+ pos += 1
+ while pos < len(args) and not args[pos].startswith("--"):
+ prompt += " "
+ prompt += args[pos]
+ pos += 1
+ res[tag] = prompt
+ continue
+
+
+ func = prompt_tags.get(tag, None)
+ assert func, f'unknown commandline option: {arg}'
+
+ val = args[pos+1]
+ if tag == "sampler_name":
+ val = sd_samplers.samplers_map.get(val.lower(), None)
+
+ res[tag] = func(val)
+
+ pos += 2
+
+ return res
+
+
+def load_prompt_file(file):
+ if file is None:
+ lines = []
+ else:
+ lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
+
+ return None, "\n".join(lines), gr.update(lines=7)
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Prompts from file or textbox"
+
+ def ui(self, is_img2img):
+ checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
+ checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
+
+ prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
+ file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
+
+ file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
+
+ # We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
+ # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
+ # be unclear to the user that shift-enter is needed.
+ prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
+ return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
+
+ def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
+ lines = [x.strip() for x in prompt_txt.splitlines()]
+ lines = [x for x in lines if len(x) > 0]
+
+ p.do_not_save_grid = True
+
+ job_count = 0
+ jobs = []
+
+ for line in lines:
+ if "--" in line:
+ try:
+ args = cmdargs(line)
+ except Exception:
+ print(f"Error parsing line {line} as commandline:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ args = {"prompt": line}
+ else:
+ args = {"prompt": line}
+
+ job_count += args.get("n_iter", p.n_iter)
+
+ jobs.append(args)
+
+ print(f"Will process {len(lines)} lines in {job_count} jobs.")
+ if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
+ p.seed = int(random.randrange(4294967294))
+
+ state.job_count = job_count
+
+ images = []
+ all_prompts = []
+ infotexts = []
+ for n, args in enumerate(jobs):
+ state.job = f"{state.job_no + 1} out of {state.job_count}"
+
+ copy_p = copy.copy(p)
+ for k, v in args.items():
+ setattr(copy_p, k, v)
+
+ proc = process_images(copy_p)
+ images += proc.images
+
+ if checkbox_iterate:
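+ # Advance the base seed by the number of images just generated so the next line
+ # continues the seed sequence instead of repeating it.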
+ p.seed = p.seed + (p.batch_size * p.n_iter)
+ all_prompts += proc.all_prompts
+ infotexts += proc.infotexts
+
+ return Processed(p, images, p.seed, "", all_prompts=all_prompts, infotexts=infotexts)
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd64d7d385b34f960eebd7aec3233a84ea90609a
--- /dev/null
+++ b/scripts/sd_upscale.py
@@ -0,0 +1,101 @@
+import math
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image
+
+from modules import processing, shared, sd_samplers, images, devices
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "SD upscale"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ info = gr.HTML("Will upscale the image by the selected scale factor; use width and height sliders to set tile size
")
+ overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
+ scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
+ upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index"))
+
+ return [info, overlap, upscaler_index, scale_factor]
+
+ def run(self, p, _, overlap, upscaler_index, scale_factor):
+ if isinstance(upscaler_index, str):
+ upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
+ processing.fix_seed(p)
+ upscaler = shared.sd_upscalers[upscaler_index]
+
+ p.extra_generation_params["SD upscale overlap"] = overlap
+ p.extra_generation_params["SD upscale upscaler"] = upscaler.name
+
+ initial_info = None
+ seed = p.seed
+
+ init_img = p.init_images[0]
+ init_img = images.flatten(init_img, opts.img2img_background_color)
+
+ if upscaler.name != "None":
+ img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
+ else:
+ img = init_img
+
+ devices.torch_gc()
+
+ grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
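+ # Cut the upscaled image into width x height tiles with the chosen overlap; each tile is
+ # re-run through img2img below and the results are recombined into one image.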
+
+ batch_size = p.batch_size
+ upscale_count = p.n_iter
+ p.n_iter = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ work.append(tiledata[2])
+
+ batch_count = math.ceil(len(work) / batch_size)
+ state.job_count = batch_count * upscale_count
+
+ print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
+
+ result_images = []
+ for n in range(upscale_count):
+ start_seed = seed + n
+ p.seed = start_seed
+
+ work_results = []
+ for i in range(batch_count):
+ p.batch_size = batch_size
+ p.init_images = work[i * batch_size:(i + 1) * batch_size]
+
+ state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
+ processed = processing.process_images(p)
+
+ if initial_info is None:
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+ result_images.append(combined_image)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
+
+ processed = Processed(p, result_images, seed, initial_info)
+
+ return processed
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
new file mode 100644
index 0000000000000000000000000000000000000000..e457d53de2ade37a53cc200d6902323d3d6f25ce
--- /dev/null
+++ b/scripts/xyz_grid.py
@@ -0,0 +1,620 @@
+from collections import namedtuple
+from copy import copy
+from itertools import permutations, chain
+import random
+import csv
+from io import StringIO
+from PIL import Image
+import numpy as np
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
+from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
+import modules.sd_samplers
+import modules.sd_models
+import modules.sd_vae
+import glob
+import os
+import re
+
+from modules.ui_components import ToolButton
+
+fill_values_symbol = "\U0001f4d2" # 📒
+
+AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
+
+
+def apply_field(field):
+ def fun(p, x, xs):
+ setattr(p, field, x)
+
+ return fun
+
+
+def apply_prompt(p, x, xs):
+ if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
+ raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
+
+ p.prompt = p.prompt.replace(xs[0], x)
+ p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+
+
+def apply_order(p, x, xs):
+ token_order = []
+
+ # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ prompt_parts = []
+
+ # Split the prompt up, taking out the tokens
+ for _, token in token_order:
+ n = p.prompt.find(token)
+ prompt_parts.append(p.prompt[0:n])
+ p.prompt = p.prompt[n + len(token):]
+
+ # Rebuild the prompt with the tokens in the order we want
+ prompt_tmp = ""
+ for idx, part in enumerate(prompt_parts):
+ prompt_tmp += part
+ prompt_tmp += x[idx]
+ p.prompt = prompt_tmp + p.prompt
+
+
+def apply_sampler(p, x, xs):
+ sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
+ if sampler_name is None:
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+ p.sampler_name = sampler_name
+
+
+def confirm_samplers(p, xs):
+ for x in xs:
+ if x.lower() not in sd_samplers.samplers_map:
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+
+def apply_checkpoint(p, x, xs):
+ info = modules.sd_models.get_closet_checkpoint_match(x)
+ if info is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
+
+
+def confirm_checkpoints(p, xs):
+ for x in xs:
+ if modules.sd_models.get_closet_checkpoint_match(x) is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
+
+
+def apply_clip_skip(p, x, xs):
+ opts.data["CLIP_stop_at_last_layers"] = x
+
+
+def apply_upscale_latent_space(p, x, xs):
+ if x.lower().strip() != '0':
+ opts.data["use_scale_latent_for_hires_fix"] = True
+ else:
+ opts.data["use_scale_latent_for_hires_fix"] = False
+
+
+def find_vae(name: str):
+ if name.lower() in ['auto', 'automatic']:
+ return modules.sd_vae.unspecified
+ if name.lower() == 'none':
+ return None
+ else:
+ choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
+ if len(choices) == 0:
+ print(f"No VAE found for {name}; using automatic")
+ return modules.sd_vae.unspecified
+ else:
+ return modules.sd_vae.vae_dict[choices[0]]
+
+
+def apply_vae(p, x, xs):
+ modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
+
+
+def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
+ p.styles.extend(x.split(','))
+
+
+def format_value_add_label(p, opt, x):
+ if type(x) == float:
+ x = round(x, 8)
+
+ return f"{opt.label}: {x}"
+
+
+def format_value(p, opt, x):
+ if type(x) == float:
+ x = round(x, 8)
+ return x
+
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
+def do_nothing(p, x, xs):
+ pass
+
+
+def format_nothing(p, opt, x):
+ return ""
+
+
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
+class AxisOption:
+ def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
+ self.label = label
+ self.type = type
+ self.apply = apply
+ self.format_value = format_value
+ self.confirm = confirm
+ self.cost = cost
+ self.choices = choices
+
+
+class AxisOptionImg2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_img2img = True
+
+class AxisOptionTxt2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.is_img2img = False
+
+
+axis_options = [
+ AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
+ AxisOption("Seed", int, apply_field("seed")),
+ AxisOption("Var. seed", int, apply_field("subseed")),
+ AxisOption("Var. strength", float, apply_field("subseed_strength")),
+ AxisOption("Steps", int, apply_field("steps")),
+ AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale")),
+ AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
+ AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
+ AxisOption("Sigma Churn", float, apply_field("s_churn")),
+ AxisOption("Sigma min", float, apply_field("s_tmin")),
+ AxisOption("Sigma max", float, apply_field("s_tmax")),
+ AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Eta", float, apply_field("eta")),
+ AxisOption("Clip skip", int, apply_clip_skip),
+ AxisOption("Denoising", float, apply_field("denoising_strength")),
+ AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
+ AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
+ AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
+ AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
+]
+
+
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
+ hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
+ title_texts = [[images.GridAnnotation(z)] for z in z_labels]
+
+ # Temporary list of all the images that are generated to be populated into the grid.
+ # Will be filled with empty images for any individual step that fails to process properly
+ image_cache = [None] * (len(xs) * len(ys) * len(zs))
+
+ processed_result = None
+ cell_mode = "P"
+ cell_size = (1, 1)
+
+ state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
+
+ def process_cell(x, y, z, ix, iy, iz):
+ nonlocal image_cache, processed_result, cell_mode, cell_size
+
+ def index(ix, iy, iz):
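+ # Row-major position in image_cache: x varies fastest, then y, one full x*y block per z.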
+ return ix + iy * len(xs) + iz * len(xs) * len(ys)
+
+ state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
+
+ processed: Processed = cell(x, y, z)
+
+ try:
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+ processed_result.all_prompts = [processed.prompt]
+ processed_result.all_seeds = [processed.seed]
+ processed_result.infotexts = [processed.infotexts[0]]
+
+ image_cache[index(ix, iy, iz)] = processed_image
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
+ except Exception:
+ image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
+
+ if first_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ if second_axes_processed == 'y':
+ for iy, y in enumerate(ys):
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iz, z in enumerate(zs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, z, ix, iy, iz)
+ elif first_axes_processed == 'y':
+ for iy, y in enumerate(ys):
+ if second_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ for iz, z in enumerate(zs):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iz, z in enumerate(zs):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, z, ix, iy, iz)
+ elif first_axes_processed == 'z':
+ for iz, z in enumerate(zs):
+ if second_axes_processed == 'x':
+ for ix, x in enumerate(xs):
+ for iy, y in enumerate(ys):
+ process_cell(x, y, z, ix, iy, iz)
+ else:
+ for iy, y in enumerate(ys):
+ for ix, x in enumerate(xs):
+ process_cell(x, y, z, ix, iy, iz)
+
+ if not processed_result:
+ print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
+ return Processed(p, [])
+
+ sub_grids = [None] * len(zs)
+ for i in range(len(zs)):
+ start_index = i * len(xs) * len(ys)
+ end_index = start_index + len(xs) * len(ys)
+ grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
+ if draw_legend:
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
+ sub_grids[i] = grid
+ if include_sub_grids and len(zs) > 1:
+ processed_result.images.insert(i+1, grid)
+
+ sub_grid_size = sub_grids[0].size
+ z_grid = images.image_grid(sub_grids, rows=1)
+ if draw_legend:
+ z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
+ processed_result.images[0] = z_grid
+
+ return processed_result, sub_grids
+
+
+class SharedSettingsStackHelper(object):
+ def __enter__(self):
+ self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
+ self.vae = opts.sd_vae
+
+ def __exit__(self, exc_type, exc_value, tb):
+ opts.data["sd_vae"] = self.vae
+ modules.sd_models.reload_model_weights()
+ modules.sd_vae.reload_vae_weights()
+
+ opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
+
+
+re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
+re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")
+
+re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
+re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
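+# Axis values may also be written as ranges, e.g. "1-5" (step 1), "1-10 (+2)" (explicit
+# increment) or "0.0-1.0 [6]" (6 evenly spaced values); they are expanded in process_axis.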
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "X/Y/Z plot"
+
+ def ui(self, is_img2img):
+ self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
+
+ with gr.Row():
+ with gr.Column(scale=19):
+ with gr.Row():
+ x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
+ x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
+ fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
+
+ with gr.Row():
+ y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
+ y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
+ fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
+
+ with gr.Row():
+ z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
+ z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
+ fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
+
+ with gr.Row(variant="compact", elem_id="axis_options"):
+ with gr.Column():
+ draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
+ no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ with gr.Column():
+ include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+ with gr.Column():
+ margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+ with gr.Row(variant="compact", elem_id="swap_axes"):
+ swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
+ swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
+ swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
+
+ def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
+ return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
+
+ xy_swap_args = [x_type, x_values, y_type, y_values]
+ swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
+ yz_swap_args = [y_type, y_values, z_type, z_values]
+ swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
+ xz_swap_args = [x_type, x_values, z_type, z_values]
+ swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
+
+ def fill(x_type):
+ axis = self.current_axis_options[x_type]
+ return ", ".join(axis.choices()) if axis.choices else gr.update()
+
+ fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
+ fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
+ fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
+
+ def select_axis(x_type):
+ return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
+
+ x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
+ y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
+ z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
+
+ self.infotext_fields = (
+ (x_type, "X Type"),
+ (x_values, "X Values"),
+ (y_type, "Y Type"),
+ (y_values, "Y Values"),
+ (z_type, "Z Type"),
+ (z_values, "Z Values"),
+ )
+
+ return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+
+ def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
+ if not no_fixed_seeds:
+ modules.processing.fix_seed(p)
+
+ if not opts.return_grid:
+ p.batch_size = 1
+
+ def process_axis(opt, vals):
+ if opt.label == 'Nothing':
+ return [0]
+
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
+
+ if opt.type == int:
+ valslist_ext = []
+
+ for val in valslist:
+ m = re_range.fullmatch(val)
+ mc = re_range_count.fullmatch(val)
+ if m is not None:
+ start = int(m.group(1))
+ end = int(m.group(2))+1
+ step = int(m.group(3)) if m.group(3) is not None else 1
+
+ valslist_ext += list(range(start, end, step))
+ elif mc is not None:
+ start = int(mc.group(1))
+ end = int(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
+ else:
+ valslist_ext.append(val)
+
+ valslist = valslist_ext
+ elif opt.type == float:
+ valslist_ext = []
+
+ for val in valslist:
+ m = re_range_float.fullmatch(val)
+ mc = re_range_count_float.fullmatch(val)
+ if m is not None:
+ start = float(m.group(1))
+ end = float(m.group(2))
+ step = float(m.group(3)) if m.group(3) is not None else 1
+
+ valslist_ext += np.arange(start, end + step, step).tolist()
+ elif mc is not None:
+ start = float(mc.group(1))
+ end = float(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
+ else:
+ valslist_ext.append(val)
+
+ valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
+
+ valslist = [opt.type(x) for x in valslist]
+
+ # Confirm options are valid before starting
+ if opt.confirm:
+ opt.confirm(p, valslist)
+
+ return valslist
+
+ x_opt = self.current_axis_options[x_type]
+ xs = process_axis(x_opt, x_values)
+
+ y_opt = self.current_axis_options[y_type]
+ ys = process_axis(y_opt, y_values)
+
+ z_opt = self.current_axis_options[z_type]
+ zs = process_axis(z_opt, z_values)
+
+ def fix_axis_seeds(axis_opt, axis_list):
+ if axis_opt.label in ['Seed', 'Var. seed']:
+ return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
+ else:
+ return axis_list
+
+ if not no_fixed_seeds:
+ xs = fix_axis_seeds(x_opt, xs)
+ ys = fix_axis_seeds(y_opt, ys)
+ zs = fix_axis_seeds(z_opt, zs)
+
+ if x_opt.label == 'Steps':
+ total_steps = sum(xs) * len(ys) * len(zs)
+ elif y_opt.label == 'Steps':
+ total_steps = sum(ys) * len(xs) * len(zs)
+ elif z_opt.label == 'Steps':
+ total_steps = sum(zs) * len(xs) * len(ys)
+ else:
+ total_steps = p.steps * len(xs) * len(ys) * len(zs)
+
+ if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
+ if x_opt.label == "Hires steps":
+ total_steps += sum(xs) * len(ys) * len(zs)
+ elif y_opt.label == "Hires steps":
+ total_steps += sum(ys) * len(xs) * len(zs)
+ elif z_opt.label == "Hires steps":
+ total_steps += sum(zs) * len(xs) * len(ys)
+ elif p.hr_second_pass_steps:
+ total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
+ else:
+ total_steps *= 2
+
+ total_steps *= p.n_iter
+
+ image_cell_count = p.n_iter * p.batch_size
+ cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
+ plural_s = 's' if len(zs) > 1 else ''
+ print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
+ shared.total_tqdm.updateTotal(total_steps)
+
+ grid_infotext = [None]
+
+ state.xyz_plot_x = AxisInfo(x_opt, xs)
+ state.xyz_plot_y = AxisInfo(y_opt, ys)
+ state.xyz_plot_z = AxisInfo(z_opt, zs)
+
+ # If one of the axes is very slow to change between (like SD model
+ # checkpoint), then make sure it is in the outer iteration of the nested
+ # `for` loop.
+ first_axes_processed = 'x'
+ second_axes_processed = 'y'
+ if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
+ first_axes_processed = 'x'
+ if y_opt.cost > z_opt.cost:
+ second_axes_processed = 'y'
+ else:
+ second_axes_processed = 'z'
+ elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
+ first_axes_processed = 'y'
+ if x_opt.cost > z_opt.cost:
+ second_axes_processed = 'x'
+ else:
+ second_axes_processed = 'z'
+ elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
+ first_axes_processed = 'z'
+ if x_opt.cost > y_opt.cost:
+ second_axes_processed = 'x'
+ else:
+ second_axes_processed = 'y'
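+ # e.g. with "Checkpoint name" (cost 1.0) on Y and "VAE" (cost 0.7) on X, Y runs in the
+ # outer loop, so the expensive checkpoint switch happens once per Y value rather than per cell.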
+
+ def cell(x, y, z):
+ if shared.state.interrupted:
+ return Processed(p, [], p.seed, "")
+
+ pc = copy(p)
+ pc.styles = pc.styles[:]
+ x_opt.apply(pc, x, xs)
+ y_opt.apply(pc, y, ys)
+ z_opt.apply(pc, z, zs)
+
+ res = process_images(pc)
+
+ if grid_infotext[0] is None:
+ pc.extra_generation_params = copy(pc.extra_generation_params)
+ pc.extra_generation_params['Script'] = self.title()
+
+ if x_opt.label != 'Nothing':
+ pc.extra_generation_params["X Type"] = x_opt.label
+ pc.extra_generation_params["X Values"] = x_values
+ if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])
+
+ if y_opt.label != 'Nothing':
+ pc.extra_generation_params["Y Type"] = y_opt.label
+ pc.extra_generation_params["Y Values"] = y_values
+ if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
+
+ if z_opt.label != 'Nothing':
+ pc.extra_generation_params["Z Type"] = z_opt.label
+ pc.extra_generation_params["Z Values"] = z_values
+ if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
+ pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])
+
+ grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
+
+ return res
+
+ with SharedSettingsStackHelper():
+ processed, sub_grids = draw_xyz_grid(
+ p,
+ xs=xs,
+ ys=ys,
+ zs=zs,
+ x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
+ y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
+ z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
+ cell=cell,
+ draw_legend=draw_legend,
+ include_lone_images=include_lone_images,
+ include_sub_grids=include_sub_grids,
+ first_axes_processed=first_axes_processed,
+ second_axes_processed=second_axes_processed,
+ margin_size=margin_size
+ )
+
+ if opts.grid_save and len(sub_grids) > 1:
+ for sub_grid in sub_grids:
+ images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
+ if opts.grid_save:
+ images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
+ return processed
diff --git a/sd/.gitattributes b/sd/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea
--- /dev/null
+++ b/sd/.gitattributes
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/sd/README.md b/sd/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..154df8298fab5ecf322016157858e08cd1bccbe1
--- /dev/null
+++ b/sd/README.md
@@ -0,0 +1,3 @@
+---
+license: apache-2.0
+---
diff --git a/webui-user.bat b/webui-user.bat
new file mode 100644
index 0000000000000000000000000000000000000000..e5a257bef06f5bfcaff1c8b33c64a767eb8b3fe5
--- /dev/null
+++ b/webui-user.bat
@@ -0,0 +1,8 @@
+@echo off
+
+set PYTHON=
+set GIT=
+set VENV_DIR=
+set COMMANDLINE_ARGS=
+
+call webui.bat
diff --git a/webui-user.sh b/webui-user.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bfa53cb7c67083ec0a01bfa420269af4d85c6c94
--- /dev/null
+++ b/webui-user.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+#########################################################
+# Uncomment and change the variables below to your need:#
+#########################################################
+
+# Install directory without trailing slash
+#install_dir="/home/$(whoami)"
+
+# Name of the subdirectory
+#clone_dir="stable-diffusion-webui"
+
+# Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention"
+#export COMMANDLINE_ARGS=""
+
+# python3 executable
+#python_cmd="python3"
+
+# git executable
+#export GIT="git"
+
+# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
+#venv_dir="venv"
+
+# script to launch to start the app
+#export LAUNCH_SCRIPT="launch.py"
+
+# install command for torch
+#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
+
+# Requirements file to use for stable-diffusion-webui
+#export REQS_FILE="requirements_versions.txt"
+
+# Fixed git repos
+#export K_DIFFUSION_PACKAGE=""
+#export GFPGAN_PACKAGE=""
+
+# Fixed git commits
+#export STABLE_DIFFUSION_COMMIT_HASH=""
+#export TAMING_TRANSFORMERS_COMMIT_HASH=""
+#export CODEFORMER_COMMIT_HASH=""
+#export BLIP_COMMIT_HASH=""
+
+# Uncomment to enable accelerated launch
+#export ACCELERATE="True"
+
+###########################################
diff --git a/webui.sh b/webui.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8cdad22d310fed20f229b09d7a3160aeb1731a85
--- /dev/null
+++ b/webui.sh
@@ -0,0 +1,186 @@
+#!/usr/bin/env bash
+#################################################
+# Please do not make any changes to this file, #
+# change the variables in webui-user.sh instead #
+#################################################
+
+# If run from macOS, load defaults from webui-macos-env.sh
+if [[ "$OSTYPE" == "darwin"* ]]; then
+ if [[ -f webui-macos-env.sh ]]
+ then
+ source ./webui-macos-env.sh
+ fi
+fi
+
+# Read variables from webui-user.sh
+# shellcheck source=/dev/null
+if [[ -f webui-user.sh ]]
+then
+ source ./webui-user.sh
+fi
+
+# Set defaults
+# Install directory without trailing slash
+if [[ -z "${install_dir}" ]]
+then
+ install_dir="/home/$(whoami)"
+fi
+
+# Name of the subdirectory (defaults to stable-diffusion-webui)
+if [[ -z "${clone_dir}" ]]
+then
+ clone_dir="stable-diffusion-webui"
+fi
+
+# python3 executable
+if [[ -z "${python_cmd}" ]]
+then
+ python_cmd="python3"
+fi
+
+# git executable
+if [[ -z "${GIT}" ]]
+then
+ export GIT="git"
+fi
+
+# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
+if [[ -z "${venv_dir}" ]]
+then
+ venv_dir="venv"
+fi
+
+if [[ -z "${LAUNCH_SCRIPT}" ]]
+then
+ LAUNCH_SCRIPT="launch.py"
+fi
+
+# this script cannot be run as root by default
+can_run_as_root=0
+
+# read any command line flags to the webui.sh script
+while getopts "f" flag > /dev/null 2>&1
+do
+ case ${flag} in
+ f) can_run_as_root=1;;
+ *) break;;
+ esac
+done
+
+# Disable sentry logging
+export ERROR_REPORTING=FALSE
+
+# Do not reinstall existing pip packages on Debian/Ubuntu
+export PIP_IGNORE_INSTALLED=0
+
+# Pretty print
+delimiter="################################################################"
+
+printf "\n%s\n" "${delimiter}"
+printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
+printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
+printf "\n%s\n" "${delimiter}"
+
+# Do not run as root
+if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)"
+ printf "\n%s\n" "${delimiter}"
+fi
+
+if [[ -d .git ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "Repo already cloned, using it as install directory"
+ printf "\n%s\n" "${delimiter}"
+ install_dir="${PWD}/../"
+ clone_dir="${PWD##*/}"
+fi
+
+# Check prerequisites
+gpu_info=$(lspci 2>/dev/null | grep VGA)
+case "$gpu_info" in
+ *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ ;;
+ *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
+ printf "\n%s\n" "${delimiter}"
+ printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "\n%s\n" "${delimiter}"
+ ;;
+ *)
+ ;;
+esac
+if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+fi
+
+for preq in "${GIT}" "${python_cmd}"
+do
+ if ! hash "${preq}" &>/dev/null
+ then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: %s is not installed, aborting...\e[0m" "${preq}"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+ fi
+done
+
+if ! "${python_cmd}" -c "import venv" &>/dev/null
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; }
+if [[ -d "${clone_dir}" ]]
+then
+ cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Clone stable-diffusion-webui"
+ printf "\n%s\n" "${delimiter}"
+ "${GIT}" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git "${clone_dir}"
+ cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+fi
+
+printf "\n%s\n" "${delimiter}"
+printf "Create and activate python venv"
+printf "\n%s\n" "${delimiter}"
+cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
+if [[ ! -d "${venv_dir}" ]]
+then
+ "${python_cmd}" -m venv "${venv_dir}"
+ first_launch=1
+fi
+# shellcheck source=/dev/null
+if [[ -f "${venv_dir}"/bin/activate ]]
+then
+ source "${venv_dir}"/bin/activate
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "Accelerating launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Launching launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+fi
diff --git a/webui_v2.bat b/webui_v2.bat
new file mode 100644
index 0000000000000000000000000000000000000000..2b7fad92f440a60160c009803795ce2d3cd2ef15
--- /dev/null
+++ b/webui_v2.bat
@@ -0,0 +1,85 @@
+@echo off
+
+if not defined PYTHON (set PYTHON=python)
+if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
+
+
+set ERROR_REPORTING=FALSE
+
+mkdir tmp 2>NUL
+
+%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :check_pip
+echo Couldn't launch python
+goto :show_stdout_stderr
+
+:check_pip
+%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Couldn't install pip
+goto :show_stdout_stderr
+
+:start_venv
+if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
+
+dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+
+for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
+echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
+%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+echo Unable to create venv in directory "%VENV_DIR%"
+goto :show_stdout_stderr
+
+:activate_venv
+set PYTHON="%VENV_DIR%\Scripts\Python.exe"
+echo venv %PYTHON%
+
+:skip_venv
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch
+
+:accelerate
+echo Checking for accelerate
+set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
+if EXIST %ACCELERATE% goto :accelerate_launch
+
+:launch
+%PYTHON% webui_v2.py %*
+pause
+exit /b
+
+:accelerate_launch
+echo Accelerating
+%ACCELERATE% launch --num_cpu_threads_per_process=6 webui_v2.py
+pause
+exit /b
+
+:show_stdout_stderr
+
+echo.
+echo exit code: %errorlevel%
+
+for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
+if %size% equ 0 goto :show_stderr
+echo.
+echo stdout:
+type tmp\stdout.txt
+
+:show_stderr
+for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
+if %size% equ 0 goto :endofscript
+echo.
+echo stderr:
+type tmp\stderr.txt
+
+:endofscript
+
+echo.
+echo Launch unsuccessful. Exiting.
+pause