diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..342f2fb97c2320fd6d9246e1ae7d74f510d73672 --- /dev/null +++ b/app.py @@ -0,0 +1,105 @@ +import streamlit as st +import numpy as np +import os +import pathlib +from inference import infer, InferenceModel + +# ----------------------------------------------------------------------------- +# class SatvisionDemoApp +# +# Directory Structure: base-directory/MOD09GA/year +# MOD09GQ/year +# MYD09GA/year +# MYD09GQ/year +# +# ----------------------------------------------------------------------------- +class SatvisionDemoApp: + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__(self): + + self.thumbnail_dir = pathlib.Path('data/thumbnails') + self.image_dir = pathlib.Path('data/images') + print(self.thumbnail_dir) + self.thumbnail_files = sorted(list(self.thumbnail_dir.glob('sv-*.png'))) + self.image_files = sorted(list(self.image_dir.glob('sv-*.npy'))) + print(list(self.image_files)) + self.thumbnail_names = [str(tn_path.name) for tn_path in self.thumbnail_files] + print(self.thumbnail_names) + + self.inferenceModel = InferenceModel() + + # ------------------------------------------------------------------------- + # render_sidebar + # ------------------------------------------------------------------------- + def render_sidebar(self): + + st.sidebar.header("Select an Image") + + for index, thumbnail in enumerate(self.thumbnail_names): + + thumbnail_path = self.thumbnail_dir / thumbnail + + # thumbnail_arr = np.load(thumbnail_path) + print(str(thumbnail_path)) + + st.sidebar.image(str(thumbnail_path), use_column_width=True, caption=thumbnail) + + # ------------------------------------------------------------------------- + # render_main_app + # ------------------------------------------------------------------------- + def render_main_app(self): + + st.title("Satvision-Base Demo") + + st.header("Image Reconstruction Process") + selected_image_index = st.sidebar.selectbox( + "Select an Image", + self.thumbnail_names) + print(selected_image_index) + + selected_image = self.load_selected_image(selected_image_index) + + image, masked_input, output = self.inferenceModel.infer(selected_image) + + col1, col2, col3 = st.columns(3, gap="large") + + # Display the selected image with a title three times side-by-side + + with col1: + st.image(image, use_column_width=True, caption="Input") + + with col2: + st.image(masked_input, use_column_width=True, caption="Input Masked") + + with col3: + st.image(output, use_column_width=True, caption="Reconstruction") + + # ------------------------------------------------------------------------- + # load_selected_image + # ------------------------------------------------------------------------- + def load_selected_image(self, image_name): + + # Load the selected image using NumPy (replace this with your image loading code) + image_name = image_name.replace('.png', '.npy') + + image = np.load(self.image_dir / image_name) + image = np.moveaxis(image, 0, 2) + return image + +# ----------------------------------------------------------------------------- +# main +# ----------------------------------------------------------------------------- +def main(): + + app = SatvisionDemoApp() + + app.render_main_app() + + app.render_sidebar() + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/ckpt_epoch_800.pth 
b/ckpt_epoch_800.pth new file mode 100644 index 0000000000000000000000000000000000000000..b9bfc25d94b3d8feb222ee478a18d5ab2f2d44fb --- /dev/null +++ b/ckpt_epoch_800.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56efed01b695ccccea2a30bf435978ffd867bc08d618b1e9be73d4849bfc007f +size 1136978755 diff --git a/data/images/sv-demo-mod09ga-11.npy b/data/images/sv-demo-mod09ga-11.npy new file mode 100644 index 0000000000000000000000000000000000000000..2f3ef672dd4e9a9d399471bd85b87efadb8425d8 --- /dev/null +++ b/data/images/sv-demo-mod09ga-11.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:357b9efeda497e81aa0f825c056f2c58b9066ec7208da0e28ba22f57217a165c +size 1032320 diff --git a/data/images/sv-demo-mod09ga-12.npy b/data/images/sv-demo-mod09ga-12.npy new file mode 100644 index 0000000000000000000000000000000000000000..921c102fc7daeb87c37f436625e0ac7b46fcc70a --- /dev/null +++ b/data/images/sv-demo-mod09ga-12.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1e4eb11effc171010ad5397121da3ae546aa5c9167a75627bd34820b08447d64 +size 1032320 diff --git a/data/images/sv-demo-mod09ga-6.npy b/data/images/sv-demo-mod09ga-6.npy new file mode 100644 index 0000000000000000000000000000000000000000..58966242a9b23b29995dff3b5e6ac09b12a0f54b --- /dev/null +++ b/data/images/sv-demo-mod09ga-6.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:814c059fbcae6269c79b477c758fa00cfaad23fde73f2b4efd97f6c4a940d40a +size 1032320 diff --git a/data/images/sv-demo-mod09ga-7.npy b/data/images/sv-demo-mod09ga-7.npy new file mode 100644 index 0000000000000000000000000000000000000000..db8b53bfe5ac914d60ebd9fabb4e095a76067e03 --- /dev/null +++ b/data/images/sv-demo-mod09ga-7.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4159969faa0b6123f5918e39e28ce24f8fc735c28fecfb9943dd1bc03b74146 +size 1032320 diff --git a/data/thumbnails/sv-demo-mod09ga-11.png b/data/thumbnails/sv-demo-mod09ga-11.png new file mode 100644 index 0000000000000000000000000000000000000000..ae054f29d5d8bd573fb71734fcae369bd7d6070e Binary files /dev/null and b/data/thumbnails/sv-demo-mod09ga-11.png differ diff --git a/data/thumbnails/sv-demo-mod09ga-12.png b/data/thumbnails/sv-demo-mod09ga-12.png new file mode 100644 index 0000000000000000000000000000000000000000..e7883bec9a585f2d8e41b1fc33f0ec4d8944771e Binary files /dev/null and b/data/thumbnails/sv-demo-mod09ga-12.png differ diff --git a/data/thumbnails/sv-demo-mod09ga-6.png b/data/thumbnails/sv-demo-mod09ga-6.png new file mode 100644 index 0000000000000000000000000000000000000000..7d15e72726e0ee7b919d05a5bf7326cd9b07507e Binary files /dev/null and b/data/thumbnails/sv-demo-mod09ga-6.png differ diff --git a/data/thumbnails/sv-demo-mod09ga-7.png b/data/thumbnails/sv-demo-mod09ga-7.png new file mode 100644 index 0000000000000000000000000000000000000000..449605ac98aec9d574848c6013e0735a15fdfc0e Binary files /dev/null and b/data/thumbnails/sv-demo-mod09ga-7.png differ diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..3f89bcde1b0f55129f37e2b43ce9cb731f78656c --- /dev/null +++ b/inference.py @@ -0,0 +1,227 @@ +import numpy as np + +import torch +import joblib +import numpy as np + +import torchvision.transforms as T +import sys + +sys.path.append('pytorch-caney') +# from pytorch_caney.models.mim.mim import build_mim_model + + +class Transform: + """ + torchvision transform which transforms the input imagery into + addition to 
generating a MiM mask + """ + + def __init__(self, config): + + self.transform_img = \ + T.Compose([ + T.ToTensor(), + T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), + ]) + + model_patch_size = config.MODEL.SWINV2.PATCH_SIZE + + self.mask_generator = SimmimMaskGenerator( + input_size=config.DATA.IMG_SIZE, + mask_patch_size=config.DATA.MASK_PATCH_SIZE, + model_patch_size=model_patch_size, + mask_ratio=config.DATA.MASK_RATIO, + ) + + def __call__(self, img): + + img = self.transform_img(img) + mask = self.mask_generator() + + return img, mask + + +class SimmimMaskGenerator: + """ + Generates the masks for masked-image-modeling + """ + def __init__(self, + input_size=192, + mask_patch_size=32, + model_patch_size=4, + mask_ratio=0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size ** 2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def __call__(self): + mask = self.make_simmim_mask(self.token_count, self.mask_count, + self.rand_size, self.scale) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + return mask + + @staticmethod + def make_simmim_mask(token_count, mask_count, rand_size, scale): + """JIT-compiled random mask generation + + Args: + token_count + mask_count + rand_size + scale + + Returns: + mask + """ + mask_idx = np.random.permutation(token_count)[:mask_count] + mask = np.zeros(token_count, dtype=np.int64) + mask[mask_idx] = 1 + mask = mask.reshape((rand_size, rand_size)) + return mask + + +class InferenceModel(object): + + def __init__(self): + self.checkpoint_path = 'ckpt_epoch_800.pth' + self.config_path = 'simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm.config.sav' + self.architecture_path = 'model.sav' + + self.config = joblib.load(self.config_path) + self.model = joblib.load(self.architecture_path) + self.load_checkpoint() + + self.transform = Transform(self.config) + + + def load_checkpoint(self): + + + checkpoint = torch.load(self.checkpoint_path, map_location='cpu') + + # re-map keys due to name change (only for loading provided models) + rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] + + for k in rpe_mlp_keys: + + checkpoint['model'][k.replace( + 'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) + + msg = self.model.load_state_dict(checkpoint['model'], strict=False) + + print(msg) + + del checkpoint + + torch.cuda.empty_cache() + + @staticmethod + def minmax_norm(img_arr): + arr_min = img_arr.min() + arr_max = img_arr.max() + img_arr_scaled = (img_arr - arr_min) / (arr_max - arr_min) + img_arr_scaled = img_arr_scaled * 255 + img_arr_scaled = img_arr_scaled.astype(np.uint8) + return img_arr_scaled + + # ------------------------------------------------------------------------- + # load_selected_image + # ------------------------------------------------------------------------- + def preprocess(self, image): + + image, mask = self.transform(image) + + image = image.unsqueeze(0) + + mask = torch.tensor(mask).unsqueeze(0) + + print(image.size()) + print(mask.shape) + + return image, mask + + # ------------------------------------------------------------------------- + # load_selected_image + # 
------------------------------------------------------------------------- + def predict(self, image, mask): + + with torch.no_grad(): + + logits = self.model.encoder(image, mask) + + image_recon = self.model.decoder(logits) + + image_recon = image_recon.numpy()[0, :, :, :] + + return image_recon + + # ------------------------------------------------------------------------- + # load_selected_image + # ------------------------------------------------------------------------- + @staticmethod + def process_mask(mask): + mask = mask.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous() + mask = mask[0, 0, :, :] + mask = np.stack([mask, mask, mask], axis=-1) + return mask + + # ------------------------------------------------------------------------- + # load_selected_image + # ------------------------------------------------------------------------- + def infer(self, image): + + image, mask = self.preprocess(image) + + img_recon = self.predict(image, mask) + + mask = self.process_mask(mask) + + img_normed = self.minmax_norm(image.numpy()[0, :, :, :]) + + print(img_normed.shape) + rgb_image = np.stack((img_normed[0, :, :], + img_normed[3, :, :], + img_normed[2, :, :]), + axis=-1) + + img_recon = self.minmax_norm(img_recon) + rgb_image_recon = np.stack((img_recon[0, :, :], + img_recon[3, :, :], + img_recon[2, :, :]), + axis=-1) + + rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon) + rgb_image_masked = np.where(mask == 1, 0, rgb_image) + rgb_recon_masked = rgb_masked# self.minmax_norm(rgb_masked) + + return rgb_image, rgb_image_masked, rgb_recon_masked + + +def infer(array_input: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + + masked_input = np.random.rand(256, 256, 3) + + output = np.random.rand(256, 256, 3) + + return masked_input, output + +if __name__ == '__main__': + inferenceModel = InferenceModel() + + image = np.load('data/images/sv-demo-mod09ga-11.npy') + print(image.shape) + image = np.moveaxis(image, 0, 2) + print(image.shape) + + inference = inferenceModel.infer(image) \ No newline at end of file diff --git a/model.sav b/model.sav new file mode 100644 index 0000000000000000000000000000000000000000..6a6a5168ed804b1e41e8bf0d49832f280972b716 --- /dev/null +++ b/model.sav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:884d0f2d5dabc20e541b43a5d69a815a63339631b4ed9a0121bf0da7cd8450ff +size 382621982 diff --git a/pytorch-caney/.readthedocs.yaml b/pytorch-caney/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38593578d9b18a71c229874b3ab2a3368e065bb4 --- /dev/null +++ b/pytorch-caney/.readthedocs.yaml @@ -0,0 +1,9 @@ +version: 2 + +build: + os: "ubuntu-20.04" + tools: + python: "3.8" + +sphinx: + fail_on_warning: true diff --git a/pytorch-caney/CODE_OF_CONDUCT.md b/pytorch-caney/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..e35c2222d797e38230f1bef8a349a44571edd53c --- /dev/null +++ b/pytorch-caney/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. 
+ +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +support@nccs.nasa.gov. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. 
Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/pytorch-caney/CONTRIBUTING.md b/pytorch-caney/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..d0e9ec87947fa51f76b86484ab61cc3ddc6f9f58 --- /dev/null +++ b/pytorch-caney/CONTRIBUTING.md @@ -0,0 +1,173 @@ +# Contributing + +When contributing to this repository, please first discuss the change you wish to make via issue, +email, or any other method with the owners of this repository before making a change. + +Please note we have a code of conduct, please follow it in all your interactions with the project. + +## Pull Request Process + +1. Ensure any install or build dependencies are removed before the end of the layer when doing a + build. +2. Update the README.md with details of changes to the interface, this includes new environment + variables, exposed ports, useful file locations and container parameters. +3. Increase the version numbers in any examples files and the README.md to the new version that this + Pull Request would represent. The versioning scheme we use is [SemVer](http://semver.org/). +4. Regenerate any additional documentation using PDOC (usage details listed below). +5. Document the proposed changes in the CHANGELOG.md file. +6. You may submit your merge request for review and the change will be reviewed. + +## Code of Conduct + +### Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or sexual identity and +orientation. 
+ +### Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +### Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +### Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +### Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at [INSERT EMAIL ADDRESS]. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +### Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ + +## Appendix + +### Generating Documentation + +This repository follows semi-automatic documentation generation. The following +is an example of how to generate documentation for a single module. + +```bash +conda activate tensorflow-caney +pdoc --html tensorflow-caney/raster.py --force +``` + +### Linting + +This project uses flake8 for PREP8 linting and format. Every submodule should include +a test section in the tests directory. Refer to the text directory for more examples. +The Python unittests library is used for these purposes. + +### Documenting Methods + +The following documentation format should be followed below each method to allow for +explicit semi-automatic documentation generation. 
+ +```bash + """ + Read raster and append data to existing Raster object + Args: + filename (str): raster filename to read from + bands (str list): list of bands to append to object, e.g ['Red'] + chunks_band (int): integer to map object to memory, z + chunks_x (int): integer to map object to memory, x + chunks_y (int): integer to map object to memory, y + Return: + raster (raster object): raster object to manipulate rasters + ---------- + Example + ---------- + raster.readraster(filename, bands) + """ +``` + +### Format of CHANGELOG + +The following describes the format for each CHANGELOG release. If there are no contributions +in any of the sections, they are removed from the description. + +```bash +## [0.0.3] - 2020-12-14 + +### Added +- Short description + +### Fixed +- Short description + +### Changed +- Short description + +### Removed +- Short description + +### Approved +Approver Name, Email +``` + +### Example Using Container in ADAPT + +```bash +module load singularity +singularity shell -B $your_mounts --nv tensorflow-caney +``` + +### Current Workflow + +```bash +module load singularity +singularity shell --nv -B /lscratch,/css,/explore/nobackup/projects/ilab,/explore/nobackup/people /explore/nobackup/projects/ilab/containers/tensorflow-caney-2022.12 +export PYTHONPATH="/explore/nobackup/people/jacaraba/development/tensorflow-caney:/adapt/nobackup/people/jacaraba/development/vhr-cnn-chm" +``` diff --git a/pytorch-caney/LICENSE b/pytorch-caney/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/pytorch-caney/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/pytorch-caney/LICENSE.md b/pytorch-caney/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/pytorch-caney/LICENSE.md @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/pytorch-caney/README.md b/pytorch-caney/README.md new file mode 100644 index 0000000000000000000000000000000000000000..344f54613fefdd624d928e3cf1bbcc66fda7db13 --- /dev/null +++ b/pytorch-caney/README.md @@ -0,0 +1,132 @@ +# pytorch-caney + +Python package for lots of Pytorch tools. + +[![DOI](https://zenodo.org/badge/472450059.svg)](https://zenodo.org/badge/latestdoi/472450059) +![CI Workflow](https://github.com/nasa-nccs-hpda/pytorch-caney/actions/workflows/ci.yml/badge.svg) +![CI to DockerHub ](https://github.com/nasa-nccs-hpda/pytorch-caney/actions/workflows/dockerhub.yml/badge.svg) +![Code style: PEP8](https://github.com/nasa-nccs-hpda/pytorch-caney/actions/workflows/lint.yml/badge.svg) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Coverage Status](https://coveralls.io/repos/github/nasa-nccs-hpda/pytorch-caney/badge.svg?branch=main)](https://coveralls.io/github/nasa-nccs-hpda/pytorch-caney?branch=main) + +## Documentation + +- Latest: https://nasa-nccs-hpda.github.io/pytorch-caney/latest + +## Objectives + +- Library to process remote sensing imagery using GPU and CPU parallelization. +- Machine Learning and Deep Learning image classification and regression. +- Agnostic array and vector-like data structures. +- User interface environments via Notebooks for easy to use AI/ML projects. +- Example notebooks for quick AI/ML start with your own data. + +## Installation + +The following library is intended to be used to accelerate the development of data science products +for remote sensing satellite imagery, or any other applications. pytorch-caney can be installed +by itself, but instructions for installing the full environments are listed under the requirements +directory so projects, examples, and notebooks can be run. + +Note: PIP installations do not include CUDA libraries for GPU support. Make sure NVIDIA libraries +are installed locally in the system if not using conda/mamba. + +```bash +module load singularity # if a module needs to be loaded +singularity build --sandbox pytorch-caney-container docker://nasanccs/pytorch-caney:latest +``` + +## Why Caney? + +"Caney" means longhouse in Taíno. + +## Contributors + +- Jordan Alexis Caraballo-Vega, jordan.a.caraballo-vega@nasa.gov +- Caleb Spradlin, caleb.s.spradlin@nasa.gov + +## Contributing + +Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md). 
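## Verifying the Environment

The Installation note above points out that pip installs do not bundle CUDA libraries, so it is worth confirming that the package imports and that PyTorch can see a GPU before launching any pre-training or fine-tuning. The snippet below is a minimal, illustrative check; it only assumes pytorch-caney and PyTorch are importable in the current environment (for example, inside the container built above).

```python
# Minimal environment check: confirm pytorch-caney is importable and
# report whether PyTorch can reach a CUDA-capable GPU.
import torch
import pytorch_caney

print("pytorch-caney version:", pytorch_caney.__version__)
print("CUDA available:", torch.cuda.is_available())
```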
+ +## SatVision + +| name | pretrain | resolution | #params | +| :---: | :---: | :---: | :---: | +| SatVision-B | MODIS-1.9-M | 192x192 | 84.5M | + +## SatVision Datasets + +| name | bands | resolution | #chips | +| :---: | :---: | :---: | :---: | +| MODIS-Small | 7 | 128x128 | 1,994,131 | + +## MODIS Surface Reflectance (MOD09GA) Band Details + +| Band Name | Bandwidth | +| :------------: | :-----------: | +| sur_refl_b01_1 | 0.620 - 0.670 | +| sur_refl_b02_1 | 0.841 - 0.876 | +| sur_refl_b03_1 | 0.459 - 0.479 | +| sur_refl_b04_1 | 0.545 - 0.565 | +| sur_refl_b05_1 | 1.230 - 1.250 | +| sur_refl_b06_1 | 1.628 - 1.652 | +| sur_refl_b07_1 | 2.105 - 2.155 | + +## Pre-training with Masked Image Modeling + +To pre-train the swinv2 base model with masked image modeling pre-training, run: +```bash +torchrun --nproc_per_node pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg --dataset --data-paths --batch-size --output --enable-amp +``` + +For example to run on a compute node with 4 GPUs and a batch size of 128 on the MODIS SatVision pre-training dataset with a base swinv2 model, run: + +```bash +singularity shell --nv -B /path/to/container/pytorch-caney-container +Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney +Singularity> torchrun --nproc_per_node 4 pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp +``` + +This example script runs the exact configuration used to make the SatVision-base model pre-training with MiM and the MODIS pre-training dataset. +```bash +singularity shell --nv -B /path/to/container/pytorch-caney-container +Singularity> cd pytorch-caney/examples/satvision +Singularity> ./run_satvision_pretrain.sh +``` + +## Fine-tuning Satvision-base +To fine-tune the satvision-base pre-trained model, run: +```bash +torchrun --nproc_per_node pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py --cfg --pretrained --dataset --data-paths --batch-size --output --enable-amp +``` + +See example config files pytorch-caney/examples/satvision/finetune_satvision_base_*.yaml to see how to structure your config file for fine-tuning. + + +## Testing +For unittests, run this bash command to run linting and unit test runs. This will execute unit tests and linting in a temporary venv environment only used for testing. 
+```bash +git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git +cd pytorch-caney; bash test.sh +``` +or run unit tests directly with container or anaconda env + +```bash +git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git +singularity build --sandbox pytorch-caney-container docker://nasanccs/pytorch-caney:latest +singularity shell --nv -B /path/to/container/pytorch-caney-container +cd pytorch-caney; python -m unittest discover pytorch_caney/tests +``` + +```bash +git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git +cd pytorch-caney; conda env create -f requirements/environment_gpu.yml; +conda activate pytorch-caney +python -m unittest discover pytorch_caney/tests +``` +## References + +- [Pytorch Lightning](https://github.com/Lightning-AI/lightning) +- [Swin Transformer](https://github.com/microsoft/Swin-Transformer) +- [SimMIM](https://github.com/microsoft/SimMIM) diff --git a/pytorch-caney/README.rst b/pytorch-caney/README.rst new file mode 100644 index 0000000000000000000000000000000000000000..91251937f7b35aa14fdc9d17ecae21ff7abf4d67 --- /dev/null +++ b/pytorch-caney/README.rst @@ -0,0 +1,164 @@ +================ +pytorch-caney +================ + +Python package for lots of Pytorch tools for geospatial science problems. + +.. image:: https://zenodo.org/badge/472450059.svg + :target: https://zenodo.org/badge/latestdoi/472450059 + +Objectives +------------ + +- Library to process remote sensing imagery using GPU and CPU parallelization. +- Machine Learning and Deep Learning image classification and regression. +- Agnostic array and vector-like data structures. +- User interface environments via Notebooks for easy to use AI/ML projects. +- Example notebooks for quick AI/ML start with your own data. + +Installation +---------------- + +The following library is intended to be used to accelerate the development of data science products +for remote sensing satellite imagery, or any other applications. pytorch-caney can be installed +by itself, but instructions for installing the full environments are listed under the requirements +directory so projects, examples, and notebooks can be run. + +Note: PIP installations do not include CUDA libraries for GPU support. Make sure NVIDIA libraries +are installed locally in the system if not using conda/mamba. + +.. code-block:: bash + + module load singularity # if a module needs to be loaded + singularity build --sandbox pytorch-caney-container docker://nasanccs/pytorch-caney:latest + + +Why Caney? +--------------- + +"Caney" means longhouse in Taíno. + +Contributors +------------- + +- Jordan Alexis Caraballo-Vega, jordan.a.caraballo-vega@nasa.gov +- Caleb Spradlin, caleb.s.spradlin@nasa.gov +- Jian Li, jian.li@nasa.gov + +Contributing +------------- + +Please see our `guide for contributing to pytorch-caney `_. 
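Verifying the Environment
---------------------------

The Installation note above points out that pip installs do not bundle CUDA libraries, so it is worth confirming that the package imports and that PyTorch can see a GPU before launching any pre-training or fine-tuning. The snippet below is a minimal, illustrative check; it only assumes pytorch-caney and PyTorch are importable in the current environment (for example, inside the container built above).

.. code-block:: python

    # Minimal environment check: confirm pytorch-caney is importable and
    # report whether PyTorch can reach a CUDA-capable GPU.
    import torch
    import pytorch_caney

    print("pytorch-caney version:", pytorch_caney.__version__)
    print("CUDA available:", torch.cuda.is_available())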
+ +SatVision +------------ + ++---------------+--------------+------------+------------+ +| Name | Pretrain | Resolution | Parameters | ++===============+==============+============+============+ +| SatVision-B | MODIS-1.9-M | 192x192 | 84.5M | ++---------------+--------------+------------+------------+ + +SatVision Datasets +----------------------- + ++---------------+-----------+------------+-------------+ +| Name | Bands | Resolution | Image Chips | ++===============+===========+============+=============+ +| MODIS-Small | 7 | 128x128 | 1,994,131 | ++---------------+-----------+------------+-------------+ + +MODIS Surface Reflectance (MOD09GA) Band Details +------------------------------------------------------ + ++-----------------+---------------+ +| Band Name | Bandwidth | ++=================+===============+ +| sur_refl_b01_1 | 0.620 - 0.670 | ++-----------------+---------------+ +| sur_refl_b02_1 | 0.841 - 0.876 | ++-----------------+---------------+ +| sur_refl_b03_1 | 0.459 - 0.479 | ++-----------------+---------------+ +| sur_refl_b04_1 | 0.545 - 0.565 | ++-----------------+---------------+ +| sur_refl_b05_1 | 1.230 - 1.250 | ++-----------------+---------------+ +| sur_refl_b06_1 | 1.628 - 1.652 | ++-----------------+---------------+ +| sur_refl_b07_1 | 2.105 - 2.155 | ++-----------------+---------------+ + +Pre-training with Masked Image Modeling +----------------------------------------- + +To pre-train the swinv2 base model with masked image modeling pre-training, run: + +.. code-block:: bash + + torchrun --nproc_per_node pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg --dataset --data-paths --batch-size --output --enable-amp + +For example to run on a compute node with 4 GPUs and a batch size of 128 on the MODIS SatVision pre-training dataset with a base swinv2 model, run: + +.. code-block:: bash + + singularity shell --nv -B /path/to/container/pytorch-caney-container + Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney + Singularity> torchrun --nproc_per_node 4 pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp + + +This example script runs the exact configuration used to make the SatVision-base model pre-training with MiM and the MODIS pre-training dataset. + +.. code-block:: bash + + singularity shell --nv -B /path/to/container/pytorch-caney-container + Singularity> cd pytorch-caney/examples/satvision + Singularity> ./run_satvision_pretrain.sh + + +Fine-tuning Satvision-base +----------------------------- + +To fine-tune the satvision-base pre-trained model, run: + +.. code-block:: bash + + torchrun --nproc_per_node pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py --cfg --pretrained --dataset --data-paths --batch-size --output --enable-amp + +See example config files pytorch-caney/examples/satvision/finetune_satvision_base_*.yaml to see how to structure your config file for fine-tuning. + + +Testing +------------ + +For unittests, run this bash command to run linting and unit test runs. This will execute unit tests and linting in a temporary venv environment only used for testing. + +.. code-block:: bash + + git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git + cd pytorch-caney; bash test.sh + + +or run unit tests directly with container or anaconda env + +.. 
code-block:: bash + + git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git + singularity build --sandbox pytorch-caney-container docker://nasanccs/pytorch-caney:latest + singularity shell --nv -B /path/to/container/pytorch-caney-container + cd pytorch-caney; python -m unittest discover pytorch_caney/tests + +.. code-block:: bash + + git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git + cd pytorch-caney; conda env create -f requirements/environment_gpu.yml; + conda activate pytorch-caney + python -m unittest discover pytorch_caney/tests + + +References +------------ + +- `Pytorch Lightning `_ +- `Swin Transformer `_ +- `SimMIM `_ diff --git a/pytorch-caney/docs/Makefile b/pytorch-caney/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..d4bb2cbb9eddb1bb1b4f366623044af8e4830919 --- /dev/null +++ b/pytorch-caney/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/pytorch-caney/docs/conf.py b/pytorch-caney/docs/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..7c505401b867dba9ac72a8119b038f2f60b74895 --- /dev/null +++ b/pytorch-caney/docs/conf.py @@ -0,0 +1,52 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) + +import pytorch_caney # noqa: E402 + +project = 'pytorch-caney' +copyright = '2023, Jordan A. Caraballo-Vega' +author = 'Jordan A. Caraballo-Vega' + +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx_autodoc_typehints', + 'jupyter_sphinx.execute', + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx_click.ext", + "sphinx.ext.githubpages", + "nbsphinx", +] + +intersphinx_mapping = { + "pyproj": ("https://pyproj4.github.io/pyproj/stable/", None), + "rasterio": ("https://rasterio.readthedocs.io/en/stable/", None), + "xarray": ("http://xarray.pydata.org/en/stable/", None), +} + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +master_doc = "index" + +version = release = pytorch_caney.__version__ + +pygments_style = "sphinx" + +todo_include_todos = False + +html_theme = 'sphinx_rtd_theme' +html_logo = 'static/DSG_LOGO_REDESIGN.png' + +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", +] + +myst_url_schemes = ("http", "https", "mailto") diff --git a/pytorch-caney/docs/examples.rst b/pytorch-caney/docs/examples.rst new file mode 100644 index 0000000000000000000000000000000000000000..4c4955572a0d17e598372f17714c0bdac4e0cf76 --- /dev/null +++ b/pytorch-caney/docs/examples.rst @@ -0,0 +1,3 @@ +.. toctree:: + :maxdepth: 2 + :caption: Contents: diff --git a/pytorch-caney/docs/index.rst b/pytorch-caney/docs/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..0e6edb0fa3d904f840f54148b090ffcf2a5ca7d7 --- /dev/null +++ b/pytorch-caney/docs/index.rst @@ -0,0 +1,22 @@ +.. 
pytorch-caney's documentation master file, created by + sphinx-quickstart on Fri Jun 23 11:32:18 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to pytorch-caney's documentation! +========================================= + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + readme + examples + modules + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/pytorch-caney/docs/make.bat b/pytorch-caney/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..32bb24529f92346af26219baed295b7488b77534 --- /dev/null +++ b/pytorch-caney/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/pytorch-caney/docs/modules.rst b/pytorch-caney/docs/modules.rst new file mode 100644 index 0000000000000000000000000000000000000000..b68725603deb689ab387ff7eba787ccfbae8eb44 --- /dev/null +++ b/pytorch-caney/docs/modules.rst @@ -0,0 +1,284 @@ +pytorch-caney package +======================== + +pytorch_caney.config +---------------------- + +.. automodule:: pytorch_caney.config + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.datamodules.finetune_datamodule +---------------------- + +.. automodule:: pytorch_caney.data.datamodules.finetune_datamodule + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.datamodules.mim_datamodule +---------------------- + +.. automodule:: pytorch_caney.data.datamodules.mim_datamodule + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.datamodules.segmentation_datamodule +---------------------- + +.. automodule:: pytorch_caney.data.datamodules.segmentation_datamodule + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.datamodules.simmim_datamodule +---------------------- + +.. automodule:: pytorch_caney.data.datamodules.simmim_datamodule + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.datasets.classification_dataset +---------------------- + +.. automodule:: pytorch_caney.data.datasets.classification_dataset + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.datasets.modis_dataset +---------------------- + +.. automodule:: pytorch_caney.data.datasets.modis_dataset + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.transforms +---------------------- + +.. automodule:: pytorch_caney.data.transforms + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.data.utils +---------------------- + +.. automodule:: pytorch_caney.data.utils + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.inference +---------------------- + +.. 
automodule:: pytorch_caney.inference + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.loss.build +---------------------- + +.. automodule:: pytorch_caney.loss.build + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.loss.utils +---------------------- + +.. automodule:: pytorch_caney.loss.utils + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.lr_scheduler +---------------------- + +.. automodule:: pytorch_caney.lr_scheduler + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.metrics +---------------------- + +.. automodule:: pytorch_caney.metrics + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.decoders.unet_decoder +---------------------- + +.. automodule:: pytorch_caney.models.decoders.unet_decoder + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.mim.mim +---------------------- + +.. automodule:: pytorch_caney.models.mim.mim + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.simmim.simmim +---------------------- + +.. automodule:: pytorch_caney.models.simmim.simmim + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.build +---------------------- + +.. automodule:: pytorch_caney.models.build + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.maskrcnn_model +---------------------- + +.. automodule:: pytorch_caney.models.maskrcnn_model + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.swinv2_model +---------------------- + +.. automodule:: pytorch_caney.models.swinv2_model + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.unet_model +---------------------- + +.. automodule:: pytorch_caney.models.unet_model + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.models.unet_swin_model +---------------------- + +.. automodule:: pytorch_caney.models.unet_swin_model + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.network.attention +---------------------- + +.. automodule:: pytorch_caney.network.attention + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.network.mlp +---------------------- + +.. automodule:: pytorch_caney.network.mlp + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.pipelines.finetuning.finetune +---------------------- + +.. automodule:: pytorch_caney.pipelines.finetuning.finetune + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.pipelines.pretraining.mim +---------------------- + +.. automodule:: pytorch_caney.pipelines.pretraining.mim + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.pipelines.modis_segmentation +---------------------- + +.. automodule:: pytorch_caney.pipelines.modis_segmentation + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.processing +---------------------- + +.. automodule:: pytorch_caney.processing + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.ptc_logging +---------------------- + +.. automodule:: pytorch_caney.ptc_logging + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.training.fine_tuning +---------------------- + +.. automodule:: pytorch_caney.training.fine_tuning + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.training.mim_utils +---------------------- + +.. automodule:: pytorch_caney.training.mim_utils + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.training.pre_training +---------------------- + +.. 
automodule:: pytorch_caney.training.pre_training + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.training.simmim_utils +---------------------- + +.. automodule:: pytorch_caney.training.simmim_utils + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.training.utils +---------------------- + +.. automodule:: pytorch_caney.training.utils + :members: + :undoc-members: + :show-inheritance: + +pytorch_caney.utils +---------------------- + +.. automodule:: pytorch_caney.utils + :members: + :undoc-members: + :show-inheritance: + + diff --git a/pytorch-caney/docs/pytorch_caney.rst b/pytorch-caney/docs/pytorch_caney.rst new file mode 100644 index 0000000000000000000000000000000000000000..ec3ad23a5da4480759d3084b1d9001ef98b63608 --- /dev/null +++ b/pytorch-caney/docs/pytorch_caney.rst @@ -0,0 +1,31 @@ +pytorch-caney +------------------ + +Python package for lots of Pytorch tools for geospatial science problems. + +Installation +----------------- + +Install with pip +:: + + pip install pytorch-caney + +API +---- + +.. automodule:: pytorch_caney + :members: + +Authors +---------- + +Jordan A. Caraballo-Vega, jordan.a.caraballo-vega@nasa.gov +Caleb S. Spradlin, caleb.s.spradlin@nasa.gov +Jian Li, jian.li@nasa.gov + +License +--------------- + +The package is released under the `MIT +License `__. diff --git a/pytorch-caney/docs/readme.rst b/pytorch-caney/docs/readme.rst new file mode 100644 index 0000000000000000000000000000000000000000..72a33558153fb57def85612b021ec596ef2a51b9 --- /dev/null +++ b/pytorch-caney/docs/readme.rst @@ -0,0 +1 @@ +.. include:: ../README.rst diff --git a/pytorch-caney/docs/requirements.txt b/pytorch-caney/docs/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4fc9df0651db16f818abd98b8f13cc4a9c3e2b4d --- /dev/null +++ b/pytorch-caney/docs/requirements.txt @@ -0,0 +1,4 @@ +sphinx +sphinx-autodoc-typehints +jupyter-sphinx +sphinx-rtd-theme diff --git a/pytorch-caney/docs/source/index.rst b/pytorch-caney/docs/source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..963fbcc4f1efff381333c7ad07bad7b554eae7d6 --- /dev/null +++ b/pytorch-caney/docs/source/index.rst @@ -0,0 +1,20 @@ +.. tensorflow-caney documentation master file, created by + sphinx-quickstart on Fri Jan 13 06:59:19 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to tensorflow-caney's documentation! +============================================ + +.. 
toctree:: + :maxdepth: 2 + :caption: Contents: + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/pytorch-caney/docs/static/DSG_LOGO_REDESIGN.png b/pytorch-caney/docs/static/DSG_LOGO_REDESIGN.png new file mode 100644 index 0000000000000000000000000000000000000000..79e45381c3f73ffc7c33aeeec0571d9087abb4e6 Binary files /dev/null and b/pytorch-caney/docs/static/DSG_LOGO_REDESIGN.png differ diff --git a/pytorch-caney/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml b/pytorch-caney/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f41c644eef076c36a097fdf6f4bc2b3e1cf3ec9 --- /dev/null +++ b/pytorch-caney/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml @@ -0,0 +1,33 @@ +MODEL: + TYPE: swinv2 + DECODER: unet + NAME: satvision_finetune_lc5class + DROP_PATH_RATE: 0.1 + NUM_CLASSES: 5 + SWINV2: + IN_CHANS: 7 + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 14 + PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] +DATA: + IMG_SIZE: 224 + DATASET: MODISLC5 + MASK_PATCH_SIZE: 32 + MASK_RATIO: 0.6 +LOSS: + NAME: 'tversky' + MODE: 'multiclass' + ALPHA: 0.4 + BETA: 0.6 +TRAIN: + EPOCHS: 100 + WARMUP_EPOCHS: 10 + BASE_LR: 1e-4 + WARMUP_LR: 5e-7 + WEIGHT_DECAY: 0.01 + LAYER_DECAY: 0.8 +PRINT_FREQ: 100 +SAVE_FREQ: 5 +TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/pytorch-caney/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml b/pytorch-caney/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e96121afffb778403c83033c6864b3ed749379c --- /dev/null +++ b/pytorch-caney/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml @@ -0,0 +1,33 @@ +MODEL: + TYPE: swinv2 + DECODER: unet + NAME: satvision_finetune_lc9class + DROP_PATH_RATE: 0.1 + NUM_CLASSES: 9 + SWINV2: + IN_CHANS: 7 + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 14 + PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] +DATA: + IMG_SIZE: 224 + DATASET: MODISLC5 + MASK_PATCH_SIZE: 32 + MASK_RATIO: 0.6 +LOSS: + NAME: 'tversky' + MODE: 'multiclass' + ALPHA: 0.4 + BETA: 0.6 +TRAIN: + EPOCHS: 100 + WARMUP_EPOCHS: 10 + BASE_LR: 1e-4 + WARMUP_LR: 5e-7 + WEIGHT_DECAY: 0.01 + LAYER_DECAY: 0.8 +PRINT_FREQ: 100 +SAVE_FREQ: 5 +TAG: satvision_finetune_land_cover_9class_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml b/pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c188bf7af7b84ca729fd52e6c66c2d295e872af --- /dev/null +++ b/pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml @@ -0,0 +1,27 @@ +MODEL: + TYPE: swinv2 + NAME: mim_satvision_pretrain + DROP_PATH_RATE: 0.1 + SWINV2: + IN_CHANS: 7 + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +DATA: + IMG_SIZE: 192 + MASK_PATCH_SIZE: 32 + MASK_RATIO: 0.6 +TRAIN: + EPOCHS: 800 + WARMUP_EPOCHS: 10 + BASE_LR: 1e-4 + WARMUP_LR: 5e-7 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + 
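+    # Reading of the two keys below (per the TRAIN.LR_SCHEDULER defaults and
+    # comments in pytorch_caney/config.py): the multistep scheduler decays the
+    # base LR by GAMMA (here a 10x drop) at each epoch listed in MULTISTEPS,
+    # i.e. a single drop at epoch 700 of the 800-epoch schedule.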
GAMMA: 0.1 + MULTISTEPS: [700,] +PRINT_FREQ: 100 +SAVE_FREQ: 5 +TAG: mim_pretrain_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/pytorch-caney/examples/satvision/run_satvision_finetune_lc_fiveclass.sh b/pytorch-caney/examples/satvision/run_satvision_finetune_lc_fiveclass.sh new file mode 100755 index 0000000000000000000000000000000000000000..155abf610525f864f78d1b47e740592bd6ee1fe4 --- /dev/null +++ b/pytorch-caney/examples/satvision/run_satvision_finetune_lc_fiveclass.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +#SBATCH -J finetune_satvision_lc5 +#SBATCH -t 3-00:00:00 +#SBATCH -G 4 +#SBATCH -N 1 + + +export PYTHONPATH=$PWD:../../../:../../../pytorch-caney +export NGPUS=8 + +torchrun --nproc_per_node $NGPUS \ + ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ + --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \ + --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \ + --dataset MODISLC9 \ + --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \ + --batch-size 4 \ + --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \ + --enable-amp \ No newline at end of file diff --git a/pytorch-caney/examples/satvision/run_satvision_finetune_lc_nineclass.sh b/pytorch-caney/examples/satvision/run_satvision_finetune_lc_nineclass.sh new file mode 100755 index 0000000000000000000000000000000000000000..7008967594c1f4640c154b2c532e9a67cc1bcad1 --- /dev/null +++ b/pytorch-caney/examples/satvision/run_satvision_finetune_lc_nineclass.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +#SBATCH -J finetune_satvision_lc9 +#SBATCH -t 3-00:00:00 +#SBATCH -G 4 +#SBATCH -N 1 + + +export PYTHONPATH=$PWD:../../../:../../../pytorch-caney +export NGPUS=8 + +torchrun --nproc_per_node $NGPUS \ + ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ + --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \ + --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \ + --dataset MODISLC9 \ + --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \ + --batch-size 4 \ + --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \ + --enable-amp \ No newline at end of file diff --git a/pytorch-caney/examples/satvision/run_satvision_pretrain.sh b/pytorch-caney/examples/satvision/run_satvision_pretrain.sh new file mode 100755 index 0000000000000000000000000000000000000000..0ff9598ac39ae4df507ff5027e6229dad9cbd6de --- /dev/null +++ b/pytorch-caney/examples/satvision/run_satvision_pretrain.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +#SBATCH -J pretrain_satvision_swinv2 +#SBATCH -t 3-00:00:00 +#SBATCH -G 4 +#SBATCH -N 1 + + +export PYTHONPATH=$PWD:../../../:../../../pytorch-caney +export NGPUS=4 + +torchrun --nproc_per_node $NGPUS \ + ../../../pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py \ + --cfg mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml \ + --dataset MODIS \ + --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \ + --batch-size 128 \ + --output 
/explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/trf/transformer/models \ + --enable-amp \ No newline at end of file diff --git a/pytorch-caney/pyproject.toml b/pytorch-caney/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..931e61b78dddd00f9c9d2d3ca2a357c9baea869f --- /dev/null +++ b/pytorch-caney/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +# Minimum requirements for the build system to execute. +requires = ["setuptools", "wheel"] + +[tool.black] +target_version = ['py39'] diff --git a/pytorch-caney/pytorch_caney/__init__.py b/pytorch-caney/pytorch_caney/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..3dc1f76bc69e3f559bee6253b24fc93acee9e1f9 --- /dev/null +++ b/pytorch-caney/pytorch_caney/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/pytorch-caney/pytorch_caney/__pycache__/__init__.cpython-310.pyc b/pytorch-caney/pytorch_caney/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8538a925ab48273c5ef5073ef9b2d92516deb37 Binary files /dev/null and b/pytorch-caney/pytorch_caney/__pycache__/__init__.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/config.py b/pytorch-caney/pytorch_caney/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d35ac2a71167a62dddba6806943a56089dd452b7 --- /dev/null +++ b/pytorch-caney/pytorch_caney/config.py @@ -0,0 +1,226 @@ +import os +import yaml +from yacs.config import CfgNode as CN + +_C = CN() + +# Base config files +_C.BASE = [''] + +# ----------------------------------------------------------------------------- +# Data settings +# ----------------------------------------------------------------------------- +_C.DATA = CN() +# Batch size for a single GPU, could be overwritten by command line argument +_C.DATA.BATCH_SIZE = 128 +# Path(s) to dataset, could be overwritten by command line argument +_C.DATA.DATA_PATHS = [''] +# Dataset name +_C.DATA.DATASET = 'MODIS' +# Input image size +_C.DATA.IMG_SIZE = 224 +# Interpolation to resize image (random, bilinear, bicubic) +_C.DATA.INTERPOLATION = 'bicubic' +# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. +_C.DATA.PIN_MEMORY = True +# Number of data loading threads +_C.DATA.NUM_WORKERS = 8 +# [SimMIM] Mask patch size for MaskGenerator +_C.DATA.MASK_PATCH_SIZE = 32 +# [SimMIM] Mask ratio for MaskGenerator +_C.DATA.MASK_RATIO = 0.6 + +# ----------------------------------------------------------------------------- +# Model settings +# ----------------------------------------------------------------------------- +_C.MODEL = CN() +# Model type +_C.MODEL.TYPE = 'swinv2' +# Decoder type +_C.MODEL.DECODER = None +# Model name +_C.MODEL.NAME = 'swinv2_base_patch4_window7_224' +# Pretrained weight from checkpoint, could be from previous pre-training +# could be overwritten by command line argument +_C.MODEL.PRETRAINED = '' +# Checkpoint to resume, could be overwritten by command line argument +_C.MODEL.RESUME = '' +# Number of classes, overwritten in data preparation +_C.MODEL.NUM_CLASSES = 17 +# Dropout rate +_C.MODEL.DROP_RATE = 0.0 +# Drop path rate +_C.MODEL.DROP_PATH_RATE = 0.1 + +# Swin Transformer V2 parameters +_C.MODEL.SWINV2 = CN() +_C.MODEL.SWINV2.PATCH_SIZE = 4 +_C.MODEL.SWINV2.IN_CHANS = 3 +_C.MODEL.SWINV2.EMBED_DIM = 96 +_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2] +_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24] +_C.MODEL.SWINV2.WINDOW_SIZE = 7 +_C.MODEL.SWINV2.MLP_RATIO = 4. 
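+# Note: the SatVision example configs override several of the SwinV2 defaults
+# above; e.g. mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml uses
+# IN_CHANS=7 (the seven MOD09GA bands), EMBED_DIM=128, DEPTHS=[2, 2, 18, 2],
+# NUM_HEADS=[4, 8, 16, 32] and WINDOW_SIZE=12 in place of the values above.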
+_C.MODEL.SWINV2.QKV_BIAS = True +_C.MODEL.SWINV2.APE = False +_C.MODEL.SWINV2.PATCH_NORM = True +_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0] + +# ----------------------------------------------------------------------------- +# Training settings +# ----------------------------------------------------------------------------- +_C.LOSS = CN() +_C.LOSS.NAME = 'tversky' +_C.LOSS.MODE = 'multiclass' +_C.LOSS.CLASSES = None +_C.LOSS.LOG = False +_C.LOSS.LOGITS = True +_C.LOSS.SMOOTH = 0.0 +_C.LOSS.IGNORE_INDEX = None +_C.LOSS.EPS = 1e-7 +_C.LOSS.ALPHA = 0.5 +_C.LOSS.BETA = 0.5 +_C.LOSS.GAMMA = 1.0 + +# ----------------------------------------------------------------------------- +# Training settings +# ----------------------------------------------------------------------------- +_C.TRAIN = CN() +_C.TRAIN.START_EPOCH = 0 +_C.TRAIN.EPOCHS = 300 +_C.TRAIN.WARMUP_EPOCHS = 20 +_C.TRAIN.WEIGHT_DECAY = 0.05 +_C.TRAIN.BASE_LR = 5e-4 +_C.TRAIN.WARMUP_LR = 5e-7 +_C.TRAIN.MIN_LR = 5e-6 +# Clip gradient norm +_C.TRAIN.CLIP_GRAD = 5.0 +# Auto resume from latest checkpoint +_C.TRAIN.AUTO_RESUME = True +# Gradient accumulation steps +# could be overwritten by command line argument +_C.TRAIN.ACCUMULATION_STEPS = 0 +# Whether to use gradient checkpointing to save memory +# could be overwritten by command line argument +_C.TRAIN.USE_CHECKPOINT = False + +# LR scheduler +_C.TRAIN.LR_SCHEDULER = CN() +_C.TRAIN.LR_SCHEDULER.NAME = 'cosine' +# Epoch interval to decay LR, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 +# LR decay rate, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 +# Gamma / Multi steps value, used in MultiStepLRScheduler +_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 +_C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] + +# Optimizer +_C.TRAIN.OPTIMIZER = CN() +_C.TRAIN.OPTIMIZER.NAME = 'adamw' +# Optimizer Epsilon +_C.TRAIN.OPTIMIZER.EPS = 1e-8 +# Optimizer Betas +_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) +# SGD momentum +_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 + +# [SimMIM] Layer decay for fine-tuning +_C.TRAIN.LAYER_DECAY = 1.0 + + +# ----------------------------------------------------------------------------- +# Testing settings +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# Whether to use center crop when testing +_C.TEST.CROP = True + +# ----------------------------------------------------------------------------- +# Misc +# ----------------------------------------------------------------------------- +# Whether to enable pytorch amp, overwritten by command line argument +_C.ENABLE_AMP = False +# Enable Pytorch automatic mixed precision (amp). 
+_C.AMP_ENABLE = True +# Path to output folder, overwritten by command line argument +_C.OUTPUT = '' +# Tag of experiment, overwritten by command line argument +_C.TAG = 'pt-caney-default-tag' +# Frequency to save checkpoint +_C.SAVE_FREQ = 1 +# Frequency to logging info +_C.PRINT_FREQ = 10 +# Fixed random seed +_C.SEED = 42 +# Perform evaluation only, overwritten by command line argument +_C.EVAL_MODE = False + + +def _update_config_from_file(config, cfg_file): + config.defrost() + with open(cfg_file, 'r') as f: + yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) + + for cfg in yaml_cfg.setdefault('BASE', ['']): + if cfg: + _update_config_from_file( + config, os.path.join(os.path.dirname(cfg_file), cfg) + ) + print('=> merge config from {}'.format(cfg_file)) + config.merge_from_file(cfg_file) + config.freeze() + + +def update_config(config, args): + _update_config_from_file(config, args.cfg) + + config.defrost() + + def _check_args(name): + if hasattr(args, name) and eval(f'args.{name}'): + return True + return False + + # merge from specific arguments + if _check_args('batch_size'): + config.DATA.BATCH_SIZE = args.batch_size + if _check_args('data_paths'): + config.DATA.DATA_PATHS = args.data_paths + if _check_args('dataset'): + config.DATA.DATASET = args.dataset + if _check_args('resume'): + config.MODEL.RESUME = args.resume + if _check_args('pretrained'): + config.MODEL.PRETRAINED = args.pretrained + if _check_args('resume'): + config.MODEL.RESUME = args.resume + if _check_args('accumulation_steps'): + config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps + if _check_args('use_checkpoint'): + config.TRAIN.USE_CHECKPOINT = True + if _check_args('disable_amp'): + config.AMP_ENABLE = False + if _check_args('output'): + config.OUTPUT = args.output + if _check_args('tag'): + config.TAG = args.tag + if _check_args('eval'): + config.EVAL_MODE = True + if _check_args('enable_amp'): + config.ENABLE_AMP = args.enable_amp + + # output folder + config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) + + config.freeze() + + +def get_config(args): + """Get a yacs CfgNode object with default values.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + config = _C.clone() + update_config(config, args) + + return config diff --git a/pytorch-caney/pytorch_caney/console/__init__.py b/pytorch-caney/pytorch_caney/console/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/console/cli.py b/pytorch-caney/pytorch_caney/console/cli.py new file mode 100755 index 0000000000000000000000000000000000000000..e02d5717481809ae5be3af713ab25138e360e5f7 --- /dev/null +++ b/pytorch-caney/pytorch_caney/console/cli.py @@ -0,0 +1,62 @@ +from pytorch_lightning.utilities.cli import LightningCLI + +import torch + + +class TerraGPULightningCLI(LightningCLI): + + def add_arguments_to_parser(self, parser): + + # Trainer - performance + parser.set_defaults({"trainer.accelerator": "auto"}) + parser.set_defaults({"trainer.devices": "auto"}) + parser.set_defaults({"trainer.auto_select_gpus": True}) + parser.set_defaults({"trainer.precision": 32}) + + # Trainer - training + parser.set_defaults({"trainer.max_epochs": 500}) + parser.set_defaults({"trainer.min_epochs": 1}) + parser.set_defaults({"trainer.detect_anomaly": True}) + parser.set_defaults({"trainer.logger": True}) + parser.set_defaults({"trainer.default_root_dir": "output_model"}) + + # 
Trainer - optimizer - TODO + _ = { + "class_path": torch.optim.Adam, + "init_args": { + "lr": 0.01 + } + } + + # Trainer - callbacks + default_callbacks = [ + {"class_path": "pytorch_lightning.callbacks.DeviceStatsMonitor"}, + { + "class_path": "pytorch_lightning.callbacks.EarlyStopping", + "init_args": { + "monitor": "val_loss", + "patience": 5, + "mode": "min" + } + }, + # { + # "class_path": "pytorch_lightning.callbacks.ModelCheckpoint", + # "init_args": { + # "dirpath": "output_model", + # "monitor": "val_loss", + # "auto_insert_metric_name": True + # } + # }, + ] + parser.set_defaults({"trainer.callbacks": default_callbacks}) + + # { + # "class_path": "pytorch_lightning.callbacks.ModelCheckpoint", + # "init_args": { + # "dirpath": "output_model", + # "monitor": "val_loss", + # "auto_insert_metric_name": True + # } + # }, + # ] + # parser.set_defaults({"trainer.callbacks": default_callbacks}) diff --git a/pytorch-caney/pytorch_caney/console/dl_pipeline.py b/pytorch-caney/pytorch_caney/console/dl_pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..4840a9570e185a8205136ab254ced2c601e70d72 --- /dev/null +++ b/pytorch-caney/pytorch_caney/console/dl_pipeline.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# RF pipeline: preprocess, train, and predict. + +import sys +import logging + +# from terragpu import unet_model +# from terragpu.decorators import DuplicateFilter +# from terragpu.ai.deep_learning.datamodules.segmentation_datamodule \ +# import SegmentationDataModule + +from pytorch_lightning import seed_everything # , trainer +# from pytorch_lightning import LightningModule, LightningDataModule +from terragpu.ai.deep_learning.console.cli import TerraGPULightningCLI + + +# ----------------------------------------------------------------------------- +# main +# +# python rf_pipeline.py options here +# ----------------------------------------------------------------------------- +def main(): + + # ------------------------------------------------------------------------- + # Set logging + # ------------------------------------------------------------------------- + logger = logging.getLogger() + logger.setLevel(logging.INFO) + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + + # Set formatter and handlers + formatter = logging.Formatter( + "%(asctime)s; %(levelname)s; %(message)s", "%Y-%m-%d %H:%M:%S") + ch.setFormatter(formatter) + logger.addHandler(ch) + + # ------------------------------------------------------------------------- + # Execute pipeline step + # ------------------------------------------------------------------------- + # Seed every library + seed_everything(1234, workers=True) + _ = TerraGPULightningCLI(save_config_callback=None) + # unet_model.UNetSegmentation, SegmentationDataModule) + + # train + # trainer = pl.Trainer() + # trainer.fit(model, datamodule=dm) + # validate + # trainer.validate(datamodule=dm) + # test + # trainer.test(datamodule=dm) + # predict + # predictions = trainer.predict(datamodule=dm) + return + + +# ----------------------------------------------------------------------------- +# Invoke the main +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + sys.exit(main()) diff --git a/pytorch-caney/pytorch_caney/data/__pycache__/utils.cpython-310.pyc b/pytorch-caney/pytorch_caney/data/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d4e6b3aa7a9380487d8fb40979189628a6893e Binary files /dev/null and 
b/pytorch-caney/pytorch_caney/data/__pycache__/utils.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/data/datamodules/__init__.py b/pytorch-caney/pytorch_caney/data/datamodules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/data/datamodules/finetune_datamodule.py b/pytorch-caney/pytorch_caney/data/datamodules/finetune_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbed40401f0dcb9cdfa54021899042a6295aadf --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datamodules/finetune_datamodule.py @@ -0,0 +1,114 @@ +from ..datasets.modis_dataset import MODISDataset +from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset +from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset + +from ..transforms import TensorResizeTransform + +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler + + +DATASETS = { + 'modis': MODISDataset, + 'modislc9': MODISLCNineDataset, + 'modislc5': MODISLCFiveDataset, + # 'modis tree': MODISTree, +} + + +def get_dataset_from_dict(dataset_name: str): + """Gets the proper dataset given a dataset name. + + Args: + dataset_name (str): name of the dataset + + Raises: + KeyError: thrown if dataset key is not present in dict + + Returns: + dataset: pytorch dataset + """ + + dataset_name = dataset_name.lower() + + try: + + dataset_to_use = DATASETS[dataset_name] + + except KeyError: + + error_msg = f"{dataset_name} is not an existing dataset" + + error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}" + + raise KeyError(error_msg) + + return dataset_to_use + + +def build_finetune_dataloaders(config, logger): + """Builds the dataloaders and datasets for a fine-tuning task. 
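+
+    Looks up the dataset class in DATASETS by config.DATA.DATASET, builds the
+    train/val datasets with a TensorResizeTransform, and wraps each in a
+    DistributedSampler-backed DataLoader.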
+ + Args: + config: config object + logger: logging logger + + Returns: + dataloader_train: training dataloader + dataloader_val: validation dataloader + """ + + transform = TensorResizeTransform(config) + + logger.info(f'Finetuning data transform:\n{transform}') + + dataset_name = config.DATA.DATASET + + logger.info(f'Dataset: {dataset_name}') + logger.info(f'Data Paths: {config.DATA.DATA_PATHS}') + + dataset_to_use = get_dataset_from_dict(dataset_name) + + logger.info(f'Dataset obj: {dataset_to_use}') + + dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS, + split="train", + img_size=config.DATA.IMG_SIZE, + transform=transform) + + dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS, + split="val", + img_size=config.DATA.IMG_SIZE, + transform=transform) + + logger.info(f'Build dataset: train images = {len(dataset_train)}') + + logger.info(f'Build dataset: val images = {len(dataset_val)}') + + sampler_train = DistributedSampler( + dataset_train, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True) + + sampler_val = DistributedSampler( + dataset_val, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=False) + + dataloader_train = DataLoader(dataset_train, + config.DATA.BATCH_SIZE, + sampler=sampler_train, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=True, + drop_last=True) + + dataloader_val = DataLoader(dataset_val, + config.DATA.BATCH_SIZE, + sampler=sampler_val, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=True, + drop_last=False) + + return dataloader_train, dataloader_val diff --git a/pytorch-caney/pytorch_caney/data/datamodules/mim_datamodule.py b/pytorch-caney/pytorch_caney/data/datamodules/mim_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..b70ee7488f4644f980329fa4a2b906c23ec75797 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datamodules/mim_datamodule.py @@ -0,0 +1,80 @@ +from ..datasets.simmim_modis_dataset import MODISDataset + +from ..transforms import SimmimTransform + +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data._utils.collate import default_collate + + +DATASETS = { + 'MODIS': MODISDataset, +} + + +def collate_fn(batch): + if not isinstance(batch[0][0], tuple): + return default_collate(batch) + else: + batch_num = len(batch) + ret = [] + for item_idx in range(len(batch[0][0])): + if batch[0][0][item_idx] is None: + ret.append(None) + else: + ret.append(default_collate( + [batch[i][0][item_idx] for i in range(batch_num)])) + ret.append(default_collate([batch[i][1] for i in range(batch_num)])) + return ret + + +def get_dataset_from_dict(dataset_name): + + try: + + dataset_to_use = DATASETS[dataset_name] + + except KeyError: + + error_msg = f"{dataset_name} is not an existing dataset" + + error_msg = f"{error_msg}. 
Available datasets: {DATASETS.keys()}" + + raise KeyError(error_msg) + + return dataset_to_use + + +def build_mim_dataloader(config, logger): + + transform = SimmimTransform(config) + + logger.info(f'Pre-train data transform:\n{transform}') + + dataset_name = config.DATA.DATASET + + dataset_to_use = get_dataset_from_dict(dataset_name) + + dataset = dataset_to_use(config, + config.DATA.DATA_PATHS, + split="train", + img_size=config.DATA.IMG_SIZE, + transform=transform) + + logger.info(f'Build dataset: train images = {len(dataset)}') + + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True) + + dataloader = DataLoader(dataset, + config.DATA.BATCH_SIZE, + sampler=sampler, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=True, + drop_last=True, + collate_fn=collate_fn) + + return dataloader diff --git a/pytorch-caney/pytorch_caney/data/datamodules/segmentation_datamodule.py b/pytorch-caney/pytorch_caney/data/datamodules/segmentation_datamodule.py new file mode 100755 index 0000000000000000000000000000000000000000..fb6d16614ea799e9a16a6cd97861825fe4e5eaa2 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datamodules/segmentation_datamodule.py @@ -0,0 +1,164 @@ +import os +import logging +from typing import Any, Union, Optional + +import torch +from torch.utils.data import DataLoader +from torch.utils.data.dataset import random_split +from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY + +from terragpu.ai.deep_learning.datasets.segmentation_dataset \ + import SegmentationDataset + + +@DATAMODULE_REGISTRY +class SegmentationDataModule(LightningDataModule): + + def __init__( + self, + + # Dataset parameters + dataset_dir: str = 'dataset/', + images_regex: str = 'dataset/images/*.tif', + labels_regex: str = 'dataset/labels/*.tif', + generate_dataset: bool = True, + tile_size: int = 256, + max_patches: Union[float, int] = 100, + augment: bool = True, + chunks: dict = {'band': 1, 'x': 2048, 'y': 2048}, + input_bands: list = ['CB', 'B', 'G', 'Y', 'R', 'RE', 'N1', 'N2'], + output_bands: list = ['B', 'G', 'R'], + seed: int = 24, + normalize: bool = True, + pytorch: bool = True, + + # Datamodule parameters + val_split: float = 0.2, + test_split: float = 0.1, + num_workers: int = os.cpu_count(), + batch_size: int = 32, + shuffle: bool = True, + pin_memory: bool = False, + drop_last: bool = False, + + # Inference parameters + raster_regex: str = 'rasters/*.tif', + + *args: Any, + **kwargs: Any, + + ) -> None: + + super().__init__(*args, **kwargs) + + # Dataset parameters + self.images_regex = images_regex + self.labels_regex = labels_regex + self.dataset_dir = dataset_dir + self.generate_dataset = generate_dataset + self.tile_size = tile_size + self.max_patches = max_patches + self.augment = augment + self.chunks = chunks + self.input_bands = input_bands + self.output_bands = output_bands + self.seed = seed + self.normalize = normalize + self.pytorch = pytorch + + self.val_split = val_split + self.test_split = test_split + self.raster_regex = raster_regex + + # Performance parameters + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + + def prepare_data(self): + if self.generate_dataset: + SegmentationDataset( + images_regex=self.images_regex, + labels_regex=self.labels_regex, + dataset_dir=self.dataset_dir, + generate_dataset=self.generate_dataset, + tile_size=self.tile_size, 
+ max_patches=self.max_patches, + augment=self.augment, + chunks=self.chunks, + input_bands=self.input_bands, + output_bands=self.output_bands, + seed=self.seed, + normalize=self.normalize, + pytorch=self.pytorch, + ) + + def setup(self, stage: Optional[str] = None): + + # Split into train, val, test + segmentation_dataset = SegmentationDataset( + images_regex=self.images_regex, + labels_regex=self.labels_regex, + dataset_dir=self.dataset_dir, + generate_dataset=False, + tile_size=self.tile_size, + max_patches=self.max_patches, + augment=self.augment, + chunks=self.chunks, + input_bands=self.input_bands, + output_bands=self.output_bands, + seed=self.seed, + normalize=self.normalize, + pytorch=self.pytorch, + ) + + # Split datasets into train, val, and test sets + val_len = round(self.val_split * len(segmentation_dataset)) + test_len = round(self.test_split * len(segmentation_dataset)) + train_len = len(segmentation_dataset) - val_len - test_len + + # Initialize datasets + self.train_set, self.val_set, self.test_set = random_split( + segmentation_dataset, lengths=[train_len, val_len, test_len], + generator=torch.Generator().manual_seed(self.seed) + ) + logging.info("Initialized datasets...") + + def train_dataloader(self) -> DataLoader: + loader = DataLoader( + self.train_set, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + ) + return loader + + def val_dataloader(self) -> DataLoader: + loader = DataLoader( + self.val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + ) + return loader + + def test_dataloader(self) -> DataLoader: + loader = DataLoader( + self.test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + ) + return loader + + def predict_dataloader(self) -> DataLoader: + raise NotImplementedError diff --git a/pytorch-caney/pytorch_caney/data/datamodules/simmim_datamodule.py b/pytorch-caney/pytorch_caney/data/datamodules/simmim_datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..b70ee7488f4644f980329fa4a2b906c23ec75797 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datamodules/simmim_datamodule.py @@ -0,0 +1,80 @@ +from ..datasets.simmim_modis_dataset import MODISDataset + +from ..transforms import SimmimTransform + +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data._utils.collate import default_collate + + +DATASETS = { + 'MODIS': MODISDataset, +} + + +def collate_fn(batch): + if not isinstance(batch[0][0], tuple): + return default_collate(batch) + else: + batch_num = len(batch) + ret = [] + for item_idx in range(len(batch[0][0])): + if batch[0][0][item_idx] is None: + ret.append(None) + else: + ret.append(default_collate( + [batch[i][0][item_idx] for i in range(batch_num)])) + ret.append(default_collate([batch[i][1] for i in range(batch_num)])) + return ret + + +def get_dataset_from_dict(dataset_name): + + try: + + dataset_to_use = DATASETS[dataset_name] + + except KeyError: + + error_msg = f"{dataset_name} is not an existing dataset" + + error_msg = f"{error_msg}. 
Available datasets: {DATASETS.keys()}" + + raise KeyError(error_msg) + + return dataset_to_use + + +def build_mim_dataloader(config, logger): + + transform = SimmimTransform(config) + + logger.info(f'Pre-train data transform:\n{transform}') + + dataset_name = config.DATA.DATASET + + dataset_to_use = get_dataset_from_dict(dataset_name) + + dataset = dataset_to_use(config, + config.DATA.DATA_PATHS, + split="train", + img_size=config.DATA.IMG_SIZE, + transform=transform) + + logger.info(f'Build dataset: train images = {len(dataset)}') + + sampler = DistributedSampler( + dataset, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True) + + dataloader = DataLoader(dataset, + config.DATA.BATCH_SIZE, + sampler=sampler, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=True, + drop_last=True, + collate_fn=collate_fn) + + return dataloader diff --git a/pytorch-caney/pytorch_caney/data/datasets/__init__.py b/pytorch-caney/pytorch_caney/data/datasets/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/data/datasets/classification_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/classification_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/data/datasets/modis_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/modis_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..89a4923b25c08fc2e75a3d9647910647ac6998b9 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datasets/modis_dataset.py @@ -0,0 +1,82 @@ +import os +import random + +import numpy as np + +from torch.utils.data import Dataset + + +class MODISDataset(Dataset): + """ + MODIS Landcover 17-class pytorch fine-tuning dataset + """ + + IMAGE_PATH = os.path.join("images") + MASK_PATH = os.path.join("labels") + + def __init__( + self, + data_paths: list, + split: str, + img_size: tuple = (256, 256), + transform=None, + ): + self.img_size = img_size + self.transform = transform + self.split = split + self.data_paths = data_paths + self.img_list = [] + self.mask_list = [] + + self._init_data_paths(self.data_paths) + + # Split between train and valid set (80/20) + random_inst = random.Random(12345) # for repeatability + n_items = len(self.img_list) + idxs = set(random_inst.sample(range(n_items), n_items // 5)) + total_idxs = set(range(n_items)) + if self.split == "train": + idxs = total_idxs - idxs + + print(f'> Found {len(idxs)} patches for this dataset ({split})') + self.img_list = [self.img_list[i] for i in idxs] + self.mask_list = [self.mask_list[i] for i in idxs] + + def _init_data_paths(self, data_paths: list) -> None: + """ + Given a list of datapaths, get all filenames matching + regex from each subdatapath and compile to a single list. 
+ """ + for data_path in data_paths: + img_path = os.path.join(data_path, self.IMAGE_PATH) + mask_path = os.path.join(data_path, self.MASK_PATH) + self.img_list.extend(self.get_filenames(img_path)) + self.mask_list.extend(self.get_filenames(mask_path)) + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, idx, transpose=True): + + # load image + img = np.load(self.img_list[idx]) + + # load mask + mask = np.load(self.mask_list[idx]) + if len(mask.shape) > 2: + mask = np.argmax(mask, axis=-1) + + # perform transformations + if self.transform is not None: + img = self.transform(img) + + return img, mask + + def get_filenames(self, path): + """ + Returns a list of absolute paths to images inside given `path` + """ + files_list = [] + for filename in sorted(os.listdir(path)): + files_list.append(os.path.join(path, filename)) + return files_list diff --git a/pytorch-caney/pytorch_caney/data/datasets/modis_lc_five_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/modis_lc_five_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c8948a065d5fdbb9d36ac196001544f67b00a683 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datasets/modis_lc_five_dataset.py @@ -0,0 +1,79 @@ +import os +from torch.utils.data import Dataset + +import numpy as np +import random + + +class MODISLCFiveDataset(Dataset): + """ + MODIS Landcover five-class pytorch fine-tuning dataset + """ + + IMAGE_PATH = os.path.join("images") + MASK_PATH = os.path.join("labels") + + def __init__( + self, + data_paths: list, + split: str, + img_size: tuple = (224, 224), + transform=None, + ): + self.img_size = img_size + self.transform = transform + self.split = split + self.data_paths = data_paths + self.img_list = [] + self.mask_list = [] + for data_path in data_paths: + img_path = os.path.join(data_path, self.IMAGE_PATH) + mask_path = os.path.join(data_path, self.MASK_PATH) + self.img_list.extend(self.get_filenames(img_path)) + self.mask_list.extend(self.get_filenames(mask_path)) + # Split between train and valid set (80/20) + + random_inst = random.Random(12345) # for repeatability + n_items = len(self.img_list) + print(f'Found {n_items} possible patches to use') + range_n_items = range(n_items) + range_n_items = random_inst.sample(range_n_items, int(n_items*0.5)) + idxs = set(random_inst.sample(range_n_items, len(range_n_items) // 5)) + total_idxs = set(range_n_items) + if split == 'train': + idxs = total_idxs - idxs + print(f'> Using {len(idxs)} patches for this dataset ({split})') + self.img_list = [self.img_list[i] for i in idxs] + self.mask_list = [self.mask_list[i] for i in idxs] + print(f'>> {split}: {len(self.img_list)}') + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, idx, transpose=True): + + # load image + img = np.load(self.img_list[idx]) + + img = np.clip(img, 0, 1.0) + + # load mask + mask = np.load(self.mask_list[idx]) + + mask = np.argmax(mask, axis=-1) + + mask = mask-1 + + # perform transformations + img = self.transform(img) + + return img, mask + + def get_filenames(self, path): + """ + Returns a list of absolute paths to images inside given `path` + """ + files_list = [] + for filename in sorted(os.listdir(path)): + files_list.append(os.path.join(path, filename)) + return files_list diff --git a/pytorch-caney/pytorch_caney/data/datasets/modis_lc_nine_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/modis_lc_nine_dataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..e601b6a68744ab1c8126dec06a8476ff27b37f49 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datasets/modis_lc_nine_dataset.py @@ -0,0 +1,79 @@ +import os +import random + +import numpy as np + +from torch.utils.data import Dataset + + +class MODISLCNineDataset(Dataset): + """ + MODIS Landcover nine-class pytorch fine-tuning dataset + """ + IMAGE_PATH = os.path.join("images") + MASK_PATH = os.path.join("labels") + + def __init__( + self, + data_paths: list, + split: str, + img_size: tuple = (224, 224), + transform=None, + ): + self.img_size = img_size + self.transform = transform + self.split = split + self.data_paths = data_paths + self.img_list = [] + self.mask_list = [] + for data_path in data_paths: + img_path = os.path.join(data_path, self.IMAGE_PATH) + mask_path = os.path.join(data_path, self.MASK_PATH) + self.img_list.extend(self.get_filenames(img_path)) + self.mask_list.extend(self.get_filenames(mask_path)) + # Split between train and valid set (80/20) + + random_inst = random.Random(12345) # for repeatability + n_items = len(self.img_list) + print(f'Found {n_items} possible patches to use') + range_n_items = range(n_items) + range_n_items = random_inst.sample(range_n_items, int(n_items*0.5)) + idxs = set(random_inst.sample(range_n_items, len(range_n_items) // 5)) + total_idxs = set(range_n_items) + if split == 'train': + idxs = total_idxs - idxs + print(f'> Using {len(idxs)} patches for this dataset ({split})') + self.img_list = [self.img_list[i] for i in idxs] + self.mask_list = [self.mask_list[i] for i in idxs] + print(f'>> {split}: {len(self.img_list)}') + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, idx, transpose=True): + + # load image + img = np.load(self.img_list[idx]) + + img = np.clip(img, 0, 1.0) + + # load mask + mask = np.load(self.mask_list[idx]) + + mask = np.argmax(mask, axis=-1) + + mask = mask-1 + + # perform transformations + img = self.transform(img) + + return img, mask + + def get_filenames(self, path): + """ + Returns a list of absolute paths to images inside given `path` + """ + files_list = [] + for filename in sorted(os.listdir(path)): + files_list.append(os.path.join(path, filename)) + return files_list diff --git a/pytorch-caney/pytorch_caney/data/datasets/object_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/object_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/data/datasets/segmentation_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/segmentation_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..d81c757c7854d9f3c09f890ec85cb33a8482920c --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datasets/segmentation_dataset.py @@ -0,0 +1,284 @@ +import os +import logging +from glob import glob +from pathlib import Path +from typing import Optional, Union + +import torch +import numpy as np +from torch.utils.data import Dataset +from torch.utils.dlpack import from_dlpack + +import xarray as xr +from terragpu.engine import array_module, df_module + +import terragpu.ai.preprocessing as preprocessing + +xp = array_module() +xf = df_module() + + +class PLSegmentationDataset(Dataset): + + def __init__( + self, + images_regex: Optional[str] = None, + labels_regex: Optional[str] = None, + dataset_dir: Optional[str] = None, + generate_dataset: bool = False, + tile_size: int = 256, + max_patches: Union[float, int] = 100, + augment: bool = 
True, + chunks: dict = {'band': 1, 'x': 2048, 'y': 2048}, + input_bands: list = ['CB', 'B', 'G', 'Y', 'R', 'RE', 'N1', 'N2'], + output_bands: list = ['B', 'G', 'R'], + seed: int = 24, + normalize: bool = True, + pytorch: bool = True): + + super().__init__() + + # Dataset metadata + self.input_bands = input_bands + self.output_bands = output_bands + self.chunks = chunks + self.tile_size = tile_size + self.seed = seed + self.max_patches = max_patches + + # Preprocessing metadata + self.generate_dataset = generate_dataset + self.normalize = normalize + + # Validate several input sources + assert dataset_dir is not None, \ + f'dataset_dir: {dataset_dir} does not exist.' + + # Setup directories structure + self.dataset_dir = dataset_dir # where to store dataset + self.images_dir = os.path.join(self.dataset_dir, 'images') + self.labels_dir = os.path.join(self.dataset_dir, 'labels') + + if self.generate_dataset: + + logging.info(f"Starting to prepare dataset: {self.dataset_dir}") + # Assert images_dir and labels_dir to be not None + self.images_regex = images_regex # images location + self.labels_regex = labels_regex # labels location + + # Create directories to store dataset + os.makedirs(self.images_dir, exist_ok=True) + os.makedirs(self.labels_dir, exist_ok=True) + + self.prepare_data() + + assert os.path.exists(self.images_dir), \ + f'{self.images_dir} does not exist. Make sure prepare_data: true.' + assert os.path.exists(self.labels_dir), \ + f'{self.labels_dir} does not exist. Make sure prepare_data: true.' + + self.files = self.get_filenames() + self.augment = augment + self.pytorch = pytorch + + # ------------------------------------------------------------------------- + # Dataset methods + # ------------------------------------------------------------------------- + def __len__(self): + return len(self.files) + + def __repr__(self): + s = 'Dataset class with {} files'.format(self.__len__()) + return s + + def __getitem__(self, idx): + + idx = idx % len(self.files) + x, y = self.open_image(idx), self.open_mask(idx) + + if self.augment: + x, y = self.transform(x, y) + return x, y + + def transform(self, x, y): + + if xp.random.random_sample() > 0.5: # flip left and right + x = torch.fliplr(x) + y = torch.fliplr(y) + if xp.random.random_sample() > 0.5: # reverse second dimension + x = torch.flipud(x) + y = torch.flipud(y) + if xp.random.random_sample() > 0.5: # rotate 90 degrees + x = torch.rot90(x, k=1, dims=[1, 2]) + y = torch.rot90(y, k=1, dims=[0, 1]) + if xp.random.random_sample() > 0.5: # rotate 180 degrees + x = torch.rot90(x, k=2, dims=[1, 2]) + y = torch.rot90(y, k=2, dims=[0, 1]) + if xp.random.random_sample() > 0.5: # rotate 270 degrees + x = torch.rot90(x, k=3, dims=[1, 2]) + y = torch.rot90(y, k=3, dims=[0, 1]) + + # standardize 0.70, 0.30 + # if np.random.random_sample() > 0.70: + # image = preprocess.standardizeLocalCalcTensor(image, means, stds) + # else: + # image = preprocess.standardizeGlobalCalcTensor(image) + return x, y + + # ------------------------------------------------------------------------- + # preprocess methods + # ------------------------------------------------------------------------- + def prepare_data(self): + + logging.info("Preparing dataset...") + images_list = sorted(glob(self.images_regex)) + labels_list = sorted(glob(self.labels_regex)) + + for image, label in zip(images_list, labels_list): + + # Read imagery from disk and process both image and mask + filename = Path(image).stem + image = xr.open_rasterio(image, chunks=self.chunks).load() + label 
= xr.open_rasterio(label, chunks=self.chunks).values + + # Modify bands if necessary - in a future version, add indices + image = preprocessing.modify_bands( + img=image, input_bands=self.input_bands, + output_bands=self.output_bands) + + # Asarray option to force array type + image = xp.asarray(image.values) + label = xp.asarray(label) + + # Move from chw to hwc, squeze mask if required + image = xp.moveaxis(image, 0, -1).astype(np.int16) + label = xp.squeeze(label) if len(label.shape) != 2 else label + logging.info(f'Label classes from image: {xp.unique(label)}') + + # Generate dataset tiles + image_tiles, label_tiles = preprocessing.gen_random_tiles( + image=image, label=label, tile_size=self.tile_size, + max_patches=self.max_patches, seed=self.seed) + logging.info(f"Tiles: {image_tiles.shape}, {label_tiles.shape}") + + # Save to disk + for id in range(image_tiles.shape[0]): + xp.save( + os.path.join(self.images_dir, f'{filename}_{id}.npy'), + image_tiles[id, :, :, :]) + xp.save( + os.path.join(self.labels_dir, f'{filename}_{id}.npy'), + label_tiles[id, :, :]) + return + + # ------------------------------------------------------------------------- + # dataset methods + # ------------------------------------------------------------------------- + def list_files(self, files_list: list = []): + + for i in os.listdir(self.images_dir): + files_list.append( + { + 'image': os.path.join(self.images_dir, i), + 'label': os.path.join(self.labels_dir, i) + } + ) + return files_list + + def open_image(self, idx: int, invert: bool = True): + # image = imread(self.files[idx]['image']) + image = xp.load(self.files[idx]['image'], allow_pickle=False) + image = image.transpose((2, 0, 1)) if invert else image + image = ( + image / xp.iinfo(image.dtype).max) if self.normalize else image + return from_dlpack(image.toDlpack()) # .to(torch.float32) + + def open_mask(self, idx: int, add_dims: bool = False): + # mask = imread(self.files[idx]['label']) + mask = xp.load(self.files[idx]['label'], allow_pickle=False) + mask = xp.expand_dims(mask, 0) if add_dims else mask + return from_dlpack(mask.toDlpack()) # .to(torch.torch.int64) + + +class SegmentationDataset(Dataset): + + def __init__( + self, dataset_dir, pytorch=True, augment=True): + + super().__init__() + + self.files: list = self.list_files(dataset_dir) + self.augment: bool = augment + self.pytorch: bool = pytorch + self.invert: bool = True + self.normalize: bool = True + self.standardize: bool = True + + # ------------------------------------------------------------------------- + # Common methods + # ------------------------------------------------------------------------- + def __len__(self): + return len(self.files) + + def __repr__(self): + s = 'Dataset class with {} files'.format(self.__len__()) + return s + + def __getitem__(self, idx): + + # get data + x = self.open_image(idx) + y = self.open_mask(idx) + + # augment the data + if self.augment: + + if xp.random.random_sample() > 0.5: # flip left and right + x = torch.fliplr(x) + y = torch.fliplr(y) + if xp.random.random_sample() > 0.5: # reverse second dimension + x = torch.flipud(x) + y = torch.flipud(y) + if xp.random.random_sample() > 0.5: # rotate 90 degrees + x = torch.rot90(x, k=1, dims=[1, 2]) + y = torch.rot90(y, k=1, dims=[0, 1]) + if xp.random.random_sample() > 0.5: # rotate 180 degrees + x = torch.rot90(x, k=2, dims=[1, 2]) + y = torch.rot90(y, k=2, dims=[0, 1]) + if xp.random.random_sample() > 0.5: # rotate 270 degrees + x = torch.rot90(x, k=3, dims=[1, 2]) + y = torch.rot90(y, k=3, 
dims=[0, 1]) + + return x, y + + # ------------------------------------------------------------------------- + # IO methods + # ------------------------------------------------------------------------- + def get_filenames(self, dataset_dir: str, files_list: list = []): + + images_dir = os.path.join(dataset_dir, 'images') + labels_dir = os.path.join(dataset_dir, 'labels') + + for i in os.listdir(images_dir): + files_list.append( + { + 'image': os.path.join(images_dir, i), + 'label': os.path.join(labels_dir, i) + } + ) + return files_list + + def open_image(self, idx: int): + image = xp.load(self.files[idx]['image'], allow_pickle=False) + if self.invert: + image = image.transpose((2, 0, 1)) + if self.normalize: + image = (image / xp.iinfo(image.dtype).max) + if self.standardize: + image = preprocessing.standardize_local(image) + return from_dlpack(image.toDlpack()).float() + + def open_mask(self, idx: int, add_dims: bool = False): + mask = xp.load(self.files[idx]['label'], allow_pickle=False) + mask = xp.expand_dims(mask, 0) if add_dims else mask + return from_dlpack(mask.toDlpack()).long() diff --git a/pytorch-caney/pytorch_caney/data/datasets/simmim_modis_dataset.py b/pytorch-caney/pytorch_caney/data/datasets/simmim_modis_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ff69735ab802df3819efbc6d830e5e62374d14b2 --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/datasets/simmim_modis_dataset.py @@ -0,0 +1,90 @@ +from ..utils import SimmimMaskGenerator + +import os +import numpy as np + +from torch.utils.data import Dataset + + +class MODISDataset(Dataset): + """ + MODIS MOD09GA pre-training dataset + """ + IMAGE_PATH = os.path.join("images") + + def __init__( + self, + config, + data_paths: list, + split: str, + img_size: tuple = (192, 192), + transform=None, + ): + + self.config = config + + self.img_size = img_size + + self.transform = transform + + self.split = split + + self.data_paths = data_paths + + self.img_list = [] + + for data_path in data_paths: + + img_path = os.path.join(data_path, self.IMAGE_PATH) + + self.img_list.extend(self.get_filenames(img_path)) + + n_items = len(self.img_list) + + print(f'> Found {n_items} patches for this dataset ({split})') + + if config.MODEL.TYPE in ['swin', 'swinv2']: + + model_patch_size = config.MODEL.SWINV2.PATCH_SIZE + + else: + + raise NotImplementedError + + self.mask_generator = SimmimMaskGenerator( + input_size=config.DATA.IMG_SIZE, + mask_patch_size=config.DATA.MASK_PATCH_SIZE, + model_patch_size=model_patch_size, + mask_ratio=config.DATA.MASK_RATIO, + ) + + def __len__(self): + + return len(self.img_list) + + def __getitem__(self, idx, transpose=True): + + # load image + img = np.load(self.img_list[idx]) + + img = np.clip(img, 0, 1.0) + + # perform transformations + img = self.transform(img) + + mask = self.mask_generator() + + return img, mask + + def get_filenames(self, path): + """ + Returns a list of absolute paths to images inside given `path` + """ + + files_list = [] + + for filename in sorted(os.listdir(path)): + + files_list.append(os.path.join(path, filename)) + + return files_list diff --git a/pytorch-caney/pytorch_caney/data/transforms.py b/pytorch-caney/pytorch_caney/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a71b592d0f6c1d7d63e484d65e736eefd8aa5d --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/transforms.py @@ -0,0 +1,64 @@ +from .utils import RandomResizedCropNP +from .utils import SimmimMaskGenerator + +import torchvision.transforms as T + 
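# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original diff): SimmimTransform
# below pairs an augmented image tensor with a random SimMIM patch mask.
# The `cfg` object is hypothetical; any config exposing DATA.IMG_SIZE,
# DATA.MASK_PATCH_SIZE, DATA.MASK_RATIO and MODEL.SWINV2.PATCH_SIZE (with
# MODEL.TYPE set to 'swinv2') should behave the same way.
#
#     import numpy as np
#
#     transform = SimmimTransform(cfg)
#     chip = np.random.rand(192, 192, 7).astype(np.float32)  # HWC chip, band count arbitrary
#     img, mask = transform(chip)
#     # img  -> torch.Tensor of shape (bands, cfg.DATA.IMG_SIZE, cfg.DATA.IMG_SIZE)
#     # mask -> numpy array on the model-patch grid, e.g. (48, 48) for
#     #         IMG_SIZE=192, MASK_PATCH_SIZE=32, PATCH_SIZE=4
# ---------------------------------------------------------------------------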
+ +class SimmimTransform: + """ + torchvision transform which transforms the input imagery into + addition to generating a MiM mask + """ + + def __init__(self, config): + + self.transform_img = \ + T.Compose([ + RandomResizedCropNP(scale=(0.67, 1.), + ratio=(3. / 4., 4. / 3.)), + T.ToTensor(), + T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), + ]) + + if config.MODEL.TYPE in ['swin', 'swinv2']: + + model_patch_size = config.MODEL.SWINV2.PATCH_SIZE + + else: + + raise NotImplementedError + + self.mask_generator = SimmimMaskGenerator( + input_size=config.DATA.IMG_SIZE, + mask_patch_size=config.DATA.MASK_PATCH_SIZE, + model_patch_size=model_patch_size, + mask_ratio=config.DATA.MASK_RATIO, + ) + + def __call__(self, img): + + img = self.transform_img(img) + mask = self.mask_generator() + + return img, mask + + +class TensorResizeTransform: + """ + torchvision transform which transforms the input imagery into + addition to generating a MiM mask + """ + + def __init__(self, config): + + self.transform_img = \ + T.Compose([ + T.ToTensor(), + T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), + ]) + + def __call__(self, img): + + img = self.transform_img(img) + + return img diff --git a/pytorch-caney/pytorch_caney/data/utils.py b/pytorch-caney/pytorch_caney/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..86d9555ae4849f6eabd7fa382a44948414e3d61d --- /dev/null +++ b/pytorch-caney/pytorch_caney/data/utils.py @@ -0,0 +1,116 @@ +import torch +import numpy as np + +from numba import njit + +# TRANSFORMS UTILS + + +class RandomResizedCropNP(object): + """ + Numpy implementation of RandomResizedCrop + """ + + def __init__(self, + scale=(0.08, 1.0), + ratio=(3.0/4.0, 4.0/3.0)): + + self.scale = scale + self.ratio = ratio + + def __call__(self, img): + + height, width = img.shape[:2] + area = height * width + + for _ in range(10): + target_area = np.random.uniform(*self.scale) * area + aspect_ratio = np.random.uniform(*self.ratio) + + w = int(round(np.sqrt(target_area * aspect_ratio))) + h = int(round(np.sqrt(target_area / aspect_ratio))) + + if np.random.random() < 0.5: + w, h = h, w + + if w <= width and h <= height: + x1 = np.random.randint(0, width - w + 1) + y1 = np.random.randint(0, height - h + 1) + cropped = img[y1:y1+h, x1:x1+w, :] + cropped = np.moveaxis(cropped, -1, 0) + cropped_resized = torch.nn.functional.interpolate( + torch.from_numpy(cropped).unsqueeze(0), + size=height, + mode='bicubic', + align_corners=False) + cropped_squeezed_numpy = cropped_resized.squeeze().numpy() + cropped_squeezed_numpy = np.moveaxis( + cropped_squeezed_numpy, 0, -1) + return cropped_squeezed_numpy + + # if crop was not successful after 10 attempts, use center crop + w = min(width, height) + x1 = (width - w) // 2 + y1 = (height - w) // 2 + cropped = img[y1:y1+w, x1:x1+w, :] + cropped = np.moveaxis(cropped, -1, 0) + cropped_resized = torch.nn.functional.interpolate(torch.from_numpy( + cropped).unsqueeze(0), + size=height, + mode='bicubic', + align_corners=False) + cropped_squeezed_numpy = cropped_resized.squeeze().numpy() + cropped_squeezed_numpy = np.moveaxis(cropped_squeezed_numpy, 0, -1) + return cropped_squeezed_numpy + + +# MASKING + +class SimmimMaskGenerator: + """ + Generates the masks for masked-image-modeling + """ + def __init__(self, + input_size=192, + mask_patch_size=32, + model_patch_size=4, + mask_ratio=0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio 
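        # Worked example (illustrative): with the defaults above,
        # input_size=192 and mask_patch_size=32 give a 6 x 6 grid of
        # maskable patches (rand_size=6, token_count=36); mask_ratio=0.6
        # masks ceil(36 * 0.6) = 22 of them, and scale = 32 // 4 = 8
        # expands the 6 x 6 mask to the 48 x 48 model-patch grid in __call__.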
+ + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size ** 2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def __call__(self): + mask = make_simmim_mask(self.token_count, self.mask_count, + self.rand_size, self.scale) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + return mask + + +@njit() +def make_simmim_mask(token_count, mask_count, rand_size, scale): + """JIT-compiled random mask generation + + Args: + token_count + mask_count + rand_size + scale + + Returns: + mask + """ + mask_idx = np.random.permutation(token_count)[:mask_count] + mask = np.zeros(token_count, dtype=np.int64) + mask[mask_idx] = 1 + mask = mask.reshape((rand_size, rand_size)) + return mask diff --git a/pytorch-caney/pytorch_caney/inference.py b/pytorch-caney/pytorch_caney/inference.py new file mode 100755 index 0000000000000000000000000000000000000000..5abad4a03fbcc219275a78f61254cff5503e711e --- /dev/null +++ b/pytorch-caney/pytorch_caney/inference.py @@ -0,0 +1,382 @@ +import logging +import math +import numpy as np + +import torch + +from tiler import Tiler, Merger + +from pytorch_caney.processing import normalize +from pytorch_caney.processing import global_standardization +from pytorch_caney.processing import local_standardization +from pytorch_caney.processing import standardize_image + +__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" +__email__ = "jordan.a.caraballo-vega@nasa.gov" +__status__ = "Production" + +# --------------------------------------------------------------------------- +# module inference +# +# Data segmentation and prediction functions. +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Module Methods +# --------------------------------------------------------------------------- +def sliding_window_tiler_multiclass( + xraster, + model, + n_classes: int, + img_size: int, + pad_style: str = 'reflect', + overlap: float = 0.50, + constant_value: int = 600, + batch_size: int = 1024, + threshold: float = 0.50, + standardization: str = None, + mean=None, + std=None, + normalize: float = 1.0, + rescale: str = None, + window: str = 'triang', # 'overlap-tile' + probability_map: bool = False + ): + """ + Sliding window using tiler. 
+ """ + + tile_channels = xraster.shape[-1] # model.layers[0].input_shape[0][-1] + print(f'Standardizing: {standardization}') + # n_classes = out of the output layer, output_shape + + tiler_image = Tiler( + data_shape=xraster.shape, + tile_shape=(img_size, img_size, tile_channels), + channel_dimension=-1, + overlap=overlap, + mode=pad_style, + constant_value=constant_value + ) + + # Define the tiler and merger based on the output size of the prediction + tiler_mask = Tiler( + data_shape=(xraster.shape[0], xraster.shape[1], n_classes), + tile_shape=(img_size, img_size, n_classes), + channel_dimension=-1, + overlap=overlap, + mode=pad_style, + constant_value=constant_value + ) + + merger = Merger(tiler=tiler_mask, window=window) + # xraster = normalize_image(xraster, normalize) + + # Iterate over the data in batches + for batch_id, batch_i in tiler_image(xraster, batch_size=batch_size): + + # Standardize + batch = batch_i.copy() + + if standardization is not None: + for item in range(batch.shape[0]): + batch[item, :, :, :] = standardize_image( + batch[item, :, :, :], standardization, mean, std) + + input_batch = batch.astype('float32') + input_batch_tensor = torch.from_numpy(input_batch) + input_batch_tensor = input_batch_tensor.transpose(-1, 1) + # input_batch_tensor = input_batch_tensor.cuda(non_blocking=True) + with torch.no_grad(): + y_batch = model(input_batch_tensor) + y_batch = y_batch.transpose(1, -1) # .cpu().numpy() + merger.add_batch(batch_id, batch_size, y_batch) + + prediction = merger.merge(unpad=True) + + if not probability_map: + if prediction.shape[-1] > 1: + prediction = np.argmax(prediction, axis=-1) + else: + prediction = np.squeeze( + np.where(prediction > threshold, 1, 0).astype(np.int16) + ) + else: + prediction = np.squeeze(prediction) + return prediction + + +# --------------------------- Segmentation Functions ----------------------- # + +def segment(image, model='model.h5', tile_size=256, channels=6, + norm_data=[], bsize=8): + """ + Applies a semantic segmentation model to an image. Ideal for non-scene + imagery. Leaves artifacts in boundaries if no post-processing is done. + :param image: image to classify (numpy array) + :param model: loaded model object + :param tile_size: tile size of patches + :param channels: number of channels + :param norm_data: numpy array with mean and std data + :param bsize: number of patches to predict at the same time + return numpy array with classified mask + """ + # Create blank array to store predicted label + seg = np.zeros((image.shape[0], image.shape[1])) + for i in range(0, image.shape[0], int(tile_size)): + for j in range(0, image.shape[1], int(tile_size)): + # If edge of tile beyond image boundary, shift it to boundary + if i + tile_size > image.shape[0]: + i = image.shape[0] - tile_size + if j + tile_size > image.shape[1]: + j = image.shape[1] - tile_size + + # Extract and normalise tile + tile = normalize( + image[i: i + tile_size, j: j + tile_size, :].astype(float), + norm_data + ) + out = model.predict( + tile.reshape( + (1, tile.shape[0], tile.shape[1], tile.shape[2]) + ).astype(float), + batch_size=4 + ) + out = out.argmax(axis=3) # get max prediction for pixel in classes + out = out.reshape(tile_size, tile_size) # reshape to tile size + seg[i: i + tile_size, j: j + tile_size] = out + return seg + + +def segment_binary(image, model='model.h5', norm_data=[], + tile_size=256, channels=6, bsize=8 + ): + """ + Applies binary semantic segmentation model to an image. Ideal for non-scene + imagery. 
Leaves artifacts in boundaries if no post-processing is done. + :param image: image to classify (numpy array) + :param model: loaded model object + :param tile_size: tile size of patches + :param channels: number of channels + :param norm_data: numpy array with mean and std data + return numpy array with classified mask + """ + # Create blank array to store predicted label + seg = np.zeros((image.shape[0], image.shape[1])) + for i in range(0, image.shape[0], int(tile_size)): + for j in range(0, image.shape[1], int(tile_size)): + # If edge of tile beyond image boundary, shift it to boundary + if i + tile_size > image.shape[0]: + i = image.shape[0] - tile_size + if j + tile_size > image.shape[1]: + j = image.shape[1] - tile_size + + # Extract and normalise tile + tile = normalize( + image[i:i + tile_size, j:j + tile_size, :].astype(float), + norm_data + ) + out = model.predict( + tile.reshape( + (1, tile.shape[0], tile.shape[1], tile.shape[2]) + ).astype(float), + batch_size=bsize + ) + out[out >= 0.5] = 1 + out[out < 0.5] = 0 + out = out.reshape(tile_size, tile_size) # reshape to tile size + seg[i:i + tile_size, j:j + tile_size] = out + return seg + + +def pad_image(img, target_size): + """ + Pad an image up to the target size. + """ + rows_missing = target_size - img.shape[0] + cols_missing = target_size - img.shape[1] + padded_img = np.pad( + img, ((0, rows_missing), (0, cols_missing), (0, 0)), 'constant' + ) + return padded_img + + +def predict_sliding(image, model='', stand_method='local', + stand_strategy='per-batch', stand_data=[], + tile_size=256, nclasses=6, overlap=0.25, spline=[] + ): + """ + Predict on tiles of exactly the network input shape. + This way nothing gets squeezed. + """ + model.eval() + stride = math.ceil(tile_size * (1 - overlap)) + tile_rows = max( + int(math.ceil((image.shape[0] - tile_size) / stride) + 1), 1 + ) # strided convolution formula + tile_cols = max( + int(math.ceil((image.shape[1] - tile_size) / stride) + 1), 1 + ) + logging.info("Need %i x %i prediction tiles @ stride %i px" % + (tile_cols, tile_rows, stride) + ) + + full_probs = np.zeros((image.shape[0], image.shape[1], nclasses)) + count_predictions = np.zeros((image.shape[0], image.shape[1], nclasses)) + tile_counter = 0 + for row in range(tile_rows): + for col in range(tile_cols): + x1 = int(col * stride) + y1 = int(row * stride) + x2 = min(x1 + tile_size, image.shape[1]) + y2 = min(y1 + tile_size, image.shape[0]) + x1 = max(int(x2 - tile_size), 0) + y1 = max(int(y2 - tile_size), 0) + + img = image[y1:y2, x1:x2] + padded_img = pad_image(img, tile_size) + tile_counter += 1 + + padded_img = np.expand_dims(padded_img, 0) + + if stand_method == 'local': + imgn = local_standardization( + padded_img, ndata=stand_data, strategy=stand_strategy + ) + elif stand_method == 'global': + imgn = global_standardization( + padded_img, strategy=stand_strategy + ) + else: + imgn = padded_img + + imgn = imgn.astype('float32') + imgn_tensor = torch.from_numpy(imgn) + imgn_tensor = imgn_tensor.transpose(-1, 1) + with torch.no_grad(): + padded_prediction = model(imgn_tensor) + # if padded_prediction.shape[1] > 1: + # padded_prediction = np.argmax(padded_prediction, axis=1) + padded_prediction = np.squeeze(padded_prediction) + padded_prediction = padded_prediction.transpose(0, -1).numpy() + prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :] + count_predictions[y1:y2, x1:x2] += 1 + full_probs[y1:y2, x1:x2] += prediction # * spline + # average the predictions in the overlapping regions + full_probs /= 
count_predictions + return full_probs + + +def predict_sliding_binary(image, model='model.h5', tile_size=256, + nclasses=6, overlap=1/3, norm_data=[] + ): + """ + Predict on tiles of exactly the network input shape. + This way nothing gets squeezed. + """ + stride = math.ceil(tile_size * (1 - overlap)) + tile_rows = max( + int(math.ceil((image.shape[0] - tile_size) / stride) + 1), 1 + ) # strided convolution formula + tile_cols = max( + int(math.ceil((image.shape[1] - tile_size) / stride) + 1), 1 + ) + logging.info("Need %i x %i prediction tiles @ stride %i px" % + (tile_cols, tile_rows, stride) + ) + full_probs = np.zeros((image.shape[0], image.shape[1], nclasses)) + count_predictions = np.zeros((image.shape[0], image.shape[1], nclasses)) + tile_counter = 0 + for row in range(tile_rows): + for col in range(tile_cols): + x1 = int(col * stride) + y1 = int(row * stride) + x2 = min(x1 + tile_size, image.shape[1]) + y2 = min(y1 + tile_size, image.shape[0]) + x1 = max(int(x2 - tile_size), 0) + y1 = max(int(y2 - tile_size), 0) + + img = image[y1:y2, x1:x2] + padded_img = pad_image(img, tile_size) + tile_counter += 1 + + imgn = normalize(padded_img, norm_data) + imgn = imgn.astype('float32') + padded_prediction = model.predict(np.expand_dims(imgn, 0))[0] + prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :] + count_predictions[y1:y2, x1:x2] += 1 + full_probs[y1:y2, x1:x2] += prediction + # average the predictions in the overlapping regions + full_probs /= count_predictions + full_probs[full_probs >= 0.8] = 1 + full_probs[full_probs < 0.8] = 0 + return full_probs.reshape((image.shape[0], image.shape[1])) + + +def predict_windowing(x, model, stand_method='local', + stand_strategy='per-batch', stand_data=[], + patch_sz=160, n_classes=5, b_size=128, spline=[] + ): + img_height = x.shape[0] + img_width = x.shape[1] + n_channels = x.shape[2] + # make extended img so that it contains integer number of patches + npatches_vertical = math.ceil(img_height / patch_sz) + npatches_horizontal = math.ceil(img_width / patch_sz) + extended_height = patch_sz * npatches_vertical + extended_width = patch_sz * npatches_horizontal + ext_x = np.zeros( + shape=(extended_height, extended_width, n_channels), dtype=np.float32 + ) + # fill extended image with mirrors: + ext_x[:img_height, :img_width, :] = x + for i in range(img_height, extended_height): + ext_x[i, :, :] = ext_x[2 * img_height - i - 1, :, :] + for j in range(img_width, extended_width): + ext_x[:, j, :] = ext_x[:, 2 * img_width - j - 1, :] + + # now we assemble all patches in one array + patches_list = [] + for i in range(0, npatches_vertical): + for j in range(0, npatches_horizontal): + x0, x1 = i * patch_sz, (i + 1) * patch_sz + y0, y1 = j * patch_sz, (j + 1) * patch_sz + patches_list.append(ext_x[x0:x1, y0:y1, :]) + patches_array = np.asarray(patches_list) + + # normalization(patches_array, ndata) + + if stand_method == 'local': # apply local zero center standardization + patches_array = local_standardization( + patches_array, ndata=stand_data, strategy=stand_strategy + ) + elif stand_method == 'global': # apply global zero center standardization + patches_array = global_standardization( + patches_array, strategy=stand_strategy + ) + + # predictions: + patches_predict = model.predict(patches_array, batch_size=b_size) + prediction = np.zeros( + shape=(extended_height, extended_width, n_classes), dtype=np.float32 + ) + logging.info("prediction shape: ", prediction.shape) + for k in range(patches_predict.shape[0]): + i = k // npatches_horizontal 
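            # Unravel the flat patch index k into grid coordinates: row i
            # advances once per npatches_horizontal patches, column j cycles
            # within a row (e.g. npatches_horizontal=4, k=6 -> i=1, j=2).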
+ j = k % npatches_horizontal + x0, x1 = i * patch_sz, (i + 1) * patch_sz + y0, y1 = j * patch_sz, (j + 1) * patch_sz + prediction[x0:x1, y0:y1, :] = patches_predict[k, :, :, :] * spline + return prediction[:img_height, :img_width, :] + + +# ------------------------------------------------------------------------------- +# module model Unit Tests +# ------------------------------------------------------------------------------- + +if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) + + # Add unit tests here diff --git a/pytorch-caney/pytorch_caney/loss/build.py b/pytorch-caney/pytorch_caney/loss/build.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1cc1629e86372db48709de57179ff3c9154815 --- /dev/null +++ b/pytorch-caney/pytorch_caney/loss/build.py @@ -0,0 +1,64 @@ +from segmentation_models_pytorch.losses import TverskyLoss + + +LOSSES = { + 'tversky': TverskyLoss, +} + + +def get_loss_from_dict(loss_name, config): + """Gets the proper loss given a loss name. + + Args: + loss_name (str): name of the loss + config: config object + + Raises: + KeyError: thrown if loss key is not present in dict + + Returns: + loss: pytorch loss + """ + + try: + + loss_to_use = LOSSES[loss_name] + + except KeyError: + + error_msg = f"{loss_name} is not an implemented loss" + + error_msg = f"{error_msg}. Available loss functions: {LOSSES.keys()}" + + raise KeyError(error_msg) + + if loss_name == 'tversky': + loss = loss_to_use(mode=config.LOSS.MODE, + classes=config.LOSS.CLASSES, + log_loss=config.LOSS.LOG, + from_logits=config.LOSS.LOGITS, + smooth=config.LOSS.SMOOTH, + ignore_index=config.LOSS.IGNORE_INDEX, + eps=config.LOSS.EPS, + alpha=config.LOSS.ALPHA, + beta=config.LOSS.BETA, + gamma=config.LOSS.GAMMA) + return loss + + +def build_loss(config): + """ + Builds the loss function given a configuration object. 
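    Illustrative usage (a sketch, assuming a config whose LOSS.NAME is
    'tversky' and whose LOSS.* fields match the keyword arguments above):

        criterion = build_loss(config)
        loss = criterion(logits, targets)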
+ + Args: + config: config object + + Returns: + loss_to_use: pytorch loss function + """ + + loss_name = config.LOSS.NAME + + loss_to_use = get_loss_from_dict(loss_name, config) + + return loss_to_use diff --git a/pytorch-caney/pytorch_caney/loss/utils.py b/pytorch-caney/pytorch_caney/loss/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..431980399bccbf1ca3adda06a7cf37cc80383034 --- /dev/null +++ b/pytorch-caney/pytorch_caney/loss/utils.py @@ -0,0 +1,26 @@ +import numpy as np + +import torch + + +# --- +# Adapted from +# https://github.com/qubvel/segmentation_models.pytorch \ +# /tree/master/segmentation_models_pytorch/losses +# --- +def to_tensor(x, dtype=None) -> torch.Tensor: + if isinstance(x, torch.Tensor): + if dtype is not None: + x = x.type(dtype) + return x + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if dtype is not None: + x = x.type(dtype) + return x + if isinstance(x, (list, tuple)): + x = np.array(x) + x = torch.from_numpy(x) + if dtype is not None: + x = x.type(dtype) + return x diff --git a/pytorch-caney/pytorch_caney/lr_scheduler.py b/pytorch-caney/pytorch_caney/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..cd693c9139fb1bcc5c97def4278c45aa84205e28 --- /dev/null +++ b/pytorch-caney/pytorch_caney/lr_scheduler.py @@ -0,0 +1,185 @@ +from bisect import bisect_right + +from timm.scheduler.cosine_lr import CosineLRScheduler +from timm.scheduler.step_lr import StepLRScheduler +from timm.scheduler.scheduler import Scheduler + +import torch +import torch.distributed as dist + + +def build_scheduler(config, optimizer, n_iter_per_epoch): + num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) + warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) + decay_steps = int( + config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) + multi_steps = [ + i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] + + lr_scheduler = None + if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_steps, + cycle_mul=1., + lr_min=config.TRAIN.MIN_LR, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + cycle_limit=1, + t_in_epochs=False, + ) + elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': + lr_scheduler = LinearLRScheduler( + optimizer, + t_initial=num_steps, + lr_min_rate=0.01, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + t_in_epochs=False, + ) + elif config.TRAIN.LR_SCHEDULER.NAME == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=decay_steps, + decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + t_in_epochs=False, + ) + elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': + lr_scheduler = MultiStepLRScheduler( + optimizer, + milestones=multi_steps, + gamma=config.TRAIN.LR_SCHEDULER.GAMMA, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + t_in_epochs=False, + ) + + return lr_scheduler + + +class LinearLRScheduler(Scheduler): + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lr_min_rate: float, + warmup_t=0, + warmup_lr_init=0., + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, + noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.t_initial = t_initial + 
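        # After warmup, _get_lr below decays each base LR linearly from
        # base_lr down to base_lr * lr_min_rate:
        #     lr = v - (v - v * lr_min_rate) * t' / (t_initial - warmup_t)
        # where v is the base LR and t' counts steps past warmup, so e.g.
        # base_lr=1e-3 with lr_min_rate=0.01 ends at 1e-5.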
self.lr_min_rate = lr_min_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / + self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + t = t - self.warmup_t + total_t = self.t_initial - self.warmup_t + lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) + for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + +class MultiStepLRScheduler(Scheduler): + def __init__(self, optimizer: torch.optim.Optimizer, + milestones, gamma=0.1, warmup_t=0, + warmup_lr_init=0, t_in_epochs=True) -> None: + super().__init__(optimizer, param_group_field="lr") + + self.milestones = milestones + self.gamma = gamma + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / + self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + assert self.warmup_t <= min(self.milestones) + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.gamma ** bisect_right(self.milestones, t)) + for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + +def setup_scaled_lr(config): + # linear scale the learning rate according to total batch size, + # may not be optimal + + batch_size = config.DATA.BATCH_SIZE + + world_size = dist.get_world_size() + + denom_const = 512.0 + + accumulation_steps = config.TRAIN.ACCUMULATION_STEPS + + linear_scaled_lr = config.TRAIN.BASE_LR * \ + batch_size * world_size / denom_const + + linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * \ + batch_size * world_size / denom_const + + linear_scaled_min_lr = config.TRAIN.MIN_LR * \ + batch_size * world_size / denom_const + + # gradient accumulation also need to scale the learning rate + if accumulation_steps > 1: + linear_scaled_lr = linear_scaled_lr * accumulation_steps + linear_scaled_warmup_lr = linear_scaled_warmup_lr * accumulation_steps + linear_scaled_min_lr = linear_scaled_min_lr * accumulation_steps + + return linear_scaled_lr, linear_scaled_warmup_lr, linear_scaled_min_lr diff --git a/pytorch-caney/pytorch_caney/metrics.py b/pytorch-caney/pytorch_caney/metrics.py new file mode 100755 index 0000000000000000000000000000000000000000..46794647cbf853f8618f025ba7c4b324288f70a4 --- /dev/null +++ b/pytorch-caney/pytorch_caney/metrics.py @@ -0,0 +1,80 @@ +import logging +from typing import List + +import torch +import numpy as np +from sklearn.metrics import accuracy_score +from sklearn.metrics import precision_score +from sklearn.metrics import recall_score + +__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" +__email__ = 
"jordan.a.caraballo-vega@nasa.gov" +__status__ = "Production" + +# --------------------------------------------------------------------------- +# module metrics +# +# General functions to compute custom metrics. +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Module Methods +# --------------------------------------------------------------------------- + +EPSILON = 1e-15 + + +# ------------------------------ Metric Functions -------------------------- # + +def iou_val(y_true, y_pred): + intersection = np.logical_and(y_true, y_pred) + union = np.logical_or(y_true, y_pred) + iou_score = np.sum(intersection) / np.sum(union) + return iou_score + + +def acc_val(y_true, y_pred): + return accuracy_score(y_true, y_pred) + + +def prec_val(y_true, y_pred): + return precision_score(y_true, y_pred, average='macro'), \ + precision_score(y_true, y_pred, average=None) + + +def recall_val(y_true, y_pred): + return recall_score(y_true, y_pred, average='macro'), \ + recall_score(y_true, y_pred, average=None) + + +def find_average(outputs: List, name: str) -> torch.Tensor: + if len(outputs[0][name].shape) == 0: + return torch.stack([x[name] for x in outputs]).mean() + return torch.cat([x[name] for x in outputs]).mean() + + +def binary_mean_iou( + logits: torch.Tensor, + targets: torch.Tensor + ) -> torch.Tensor: + + output = (logits > 0).int() + + if output.shape != targets.shape: + targets = torch.squeeze(targets, 1) + + intersection = (targets * output).sum() + + union = targets.sum() + output.sum() - intersection + + result = (intersection + EPSILON) / (union + EPSILON) + + return result + + +# ------------------------------------------------------------------------------- +# module metrics Unit Tests +# ------------------------------------------------------------------------------- +if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) diff --git a/pytorch-caney/pytorch_caney/models/__init__.py b/pytorch-caney/pytorch_caney/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/models/__pycache__/__init__.cpython-310.pyc b/pytorch-caney/pytorch_caney/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f53c195ff4a62673e9d3834ddd941c628e01c267 Binary files /dev/null and b/pytorch-caney/pytorch_caney/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/models/__pycache__/swinv2_model.cpython-310.pyc b/pytorch-caney/pytorch_caney/models/__pycache__/swinv2_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21a6149906c336d984f35e5d223b4da37dc49e67 Binary files /dev/null and b/pytorch-caney/pytorch_caney/models/__pycache__/swinv2_model.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/models/build.py b/pytorch-caney/pytorch_caney/models/build.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffc47c4b781ca29f3502d31af9dd70ddd3506f7 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/build.py @@ -0,0 +1,105 @@ +from .swinv2_model import SwinTransformerV2 +from .unet_swin_model import unet_swin +from .mim.mim import build_mim_model +from ..training.mim_utils import load_pretrained + +import logging + + +def build_model(config, + pretrain: bool = False, + pretrain_method: str = 'mim', + logger: 
logging.Logger = None): + """ + Given a config object, builds a pytorch model. + + Returns: + model: built model + """ + + if pretrain: + + if pretrain_method == 'mim': + model = build_mim_model(config) + return model + + encoder_architecture = config.MODEL.TYPE + decoder_architecture = config.MODEL.DECODER + + if encoder_architecture == 'swinv2': + + logger.info(f'Hit encoder only build, building {encoder_architecture}') + + window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES + + model = SwinTransformerV2( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWINV2.PATCH_SIZE, + in_chans=config.MODEL.SWINV2.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWINV2.EMBED_DIM, + depths=config.MODEL.SWINV2.DEPTHS, + num_heads=config.MODEL.SWINV2.NUM_HEADS, + window_size=config.MODEL.SWINV2.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, + qkv_bias=config.MODEL.SWINV2.QKV_BIAS, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWINV2.APE, + patch_norm=config.MODEL.SWINV2.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + pretrained_window_sizes=window_sizes) + + if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): + load_pretrained(config, model, logger) + + else: + + errorMsg = f'Unknown encoder architecture {encoder_architecture}' + + logger.error(errorMsg) + + raise NotImplementedError(errorMsg) + + if decoder_architecture is not None: + + if encoder_architecture == 'swinv2': + + window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES + + model = SwinTransformerV2( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWINV2.PATCH_SIZE, + in_chans=config.MODEL.SWINV2.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWINV2.EMBED_DIM, + depths=config.MODEL.SWINV2.DEPTHS, + num_heads=config.MODEL.SWINV2.NUM_HEADS, + window_size=config.MODEL.SWINV2.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, + qkv_bias=config.MODEL.SWINV2.QKV_BIAS, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWINV2.APE, + patch_norm=config.MODEL.SWINV2.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + pretrained_window_sizes=window_sizes) + + else: + + raise NotImplementedError() + + if decoder_architecture == 'unet': + + num_classes = config.MODEL.NUM_CLASSES + + if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): + load_pretrained(config, model, logger) + + model = unet_swin(encoder=model, num_classes=num_classes) + + else: + error_msg = f'Unknown decoder architecture: {decoder_architecture}' + raise NotImplementedError(error_msg) + + return model diff --git a/pytorch-caney/pytorch_caney/models/decoders/unet_decoder.py b/pytorch-caney/pytorch_caney/models/decoders/unet_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b55fcb30e1e9d38cca77aa4589794f415bdf4317 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/decoders/unet_decoder.py @@ -0,0 +1,181 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from segmentation_models_pytorch.base import modules as md + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + skip_channels, + out_channels, + use_batchnorm=True, + attention_type=None, + ): + super().__init__() + + self.conv1 = md.Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + in_and_skip_channels = in_channels + skip_channels + + 
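        # Note: forward() below upsamples x by 2 and concatenates the encoder
        # skip along the channel axis, so attention1 (and conv1 above) must
        # accept in_channels + skip_channels input channels.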
self.attention1 = md.Attention(attention_type, + in_channels=in_and_skip_channels) + + self.conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + + self.attention2 = md.Attention(attention_type, + in_channels=out_channels) + + self.in_channels = in_channels + self.out_channels = out_channels + self.skip_channels = skip_channels + + def forward(self, x, skip=None): + + if skip is None: + x = F.interpolate(x, scale_factor=2, mode="nearest") + + else: + + if x.shape[-1] != skip.shape[-1]: + x = F.interpolate(x, scale_factor=2, mode="nearest") + + if skip is not None: + + x = torch.cat([x, skip], dim=1) + x = self.attention1(x) + + x = self.conv1(x) + x = self.conv2(x) + x = self.attention2(x) + + return x + + +class CenterBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + conv1 = md.Conv2dReLU( + in_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + conv2 = md.Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + super().__init__(conv1, conv2) + + +class UnetDecoder(nn.Module): + def __init__(self, + encoder_channels, + decoder_channels, + n_blocks=5, + use_batchnorm=True, + attention_type=None, + center=False): + super().__init__() + + if n_blocks != len(decoder_channels): + raise ValueError( + f"Model depth is {n_blocks}, but you provided " + f"decoder_channels for {len(decoder_channels)} blocks." + ) + + # remove first skip with same spatial resolution + encoder_channels = encoder_channels[1:] + + # reverse channels to start from head of encoder + encoder_channels = encoder_channels[::-1] + + # computing blocks input and output channels + head_channels = encoder_channels[0] + + in_channels = [head_channels] + list(decoder_channels[:-1]) + + skip_channels = list(encoder_channels[1:]) + [0] + + out_channels = decoder_channels + + if center: + + self.center = CenterBlock( + head_channels, head_channels, use_batchnorm=use_batchnorm) + + else: + + self.center = nn.Identity() + + # combine decoder keyword arguments + kwargs = dict(use_batchnorm=use_batchnorm, + attention_type=attention_type) + + blocks = [ + DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) + for in_ch, skip_ch, out_ch in zip(in_channels, + skip_channels, + out_channels) + ] + + self.blocks = nn.ModuleList(blocks) + + def forward(self, *features): + + features = features[1:] + + # remove first skip with same spatial resolution + + features = features[:: -1] + # reverse channels to start from head of encoder + + head = features[0] + + skips = features[1:] + + x = self.center(head) + + for i, decoder_block in enumerate(self.blocks): + + skip = skips[i] if i < len(skips) else None + + x = decoder_block(x, skip) + + return x + + +class SegmentationHead(nn.Sequential): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + upsampling=1): + + conv2d = nn.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2) + + upsampling = nn.UpsamplingBilinear2d( + scale_factor=upsampling) if upsampling > 1 else nn.Identity() + + super().__init__(conv2d, upsampling) diff --git a/pytorch-caney/pytorch_caney/models/maskrcnn_model.py b/pytorch-caney/pytorch_caney/models/maskrcnn_model.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/models/mim/__init__.py 
b/pytorch-caney/pytorch_caney/models/mim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/models/mim/__pycache__/__init__.cpython-310.pyc b/pytorch-caney/pytorch_caney/models/mim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1adf432f56de75df1356cc3bde087de116011a Binary files /dev/null and b/pytorch-caney/pytorch_caney/models/mim/__pycache__/__init__.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/models/mim/__pycache__/mim.cpython-310.pyc b/pytorch-caney/pytorch_caney/models/mim/__pycache__/mim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..389c78e10cfe47181953ea1adc5fa2af9aab21e4 Binary files /dev/null and b/pytorch-caney/pytorch_caney/models/mim/__pycache__/mim.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/models/mim/mim.py b/pytorch-caney/pytorch_caney/models/mim/mim.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbf7d138bd7b93e6e0dafdb8e513f1662f30642 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/mim/mim.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_ + +from ..swinv2_model import SwinTransformerV2 + + +class SwinTransformerV2ForSimMIM(SwinTransformerV2): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + assert self.num_classes == 0 + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + trunc_normal_(self.mask_token, mean=0., std=.02) + + def forward(self, x, mask): + x = self.patch_embed(x) + + assert mask is not None + B, L, _ = x.shape + + mask_tokens = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) + x = x * (1. - w) + mask_tokens * w + + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + x = self.norm(x) + + x = x.transpose(1, 2) + B, C, L = x.shape + H = W = int(L ** 0.5) + x = x.reshape(B, C, H, W) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return super().no_weight_decay() | {'mask_token'} + + +class MiMModel(nn.Module): + """ + Masked-Image-Modeling model + + Given an encoder, makes a model that incorporates + the encoder and attaches a simple linear layer that + produces the raw-pixel predictions of the masked + inputs. + """ + def __init__(self, encoder, encoder_stride, in_chans, patch_size): + super().__init__() + self.encoder = encoder + self.encoder_stride = encoder_stride + self.in_chans = in_chans + self.patch_size = patch_size + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.num_features, + out_channels=self.encoder_stride ** 2 * self.in_chans, + kernel_size=1), + nn.PixelShuffle(self.encoder_stride), + ) + + # self.in_chans = self.encoder.in_chans + # self.patch_size = self.encoder.patch_size + + def forward(self, x, mask): + z = self.encoder(x, mask) + x_rec = self.decoder(z) + + mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave( + self.patch_size, 2).unsqueeze(1).contiguous() + loss_recon = F.l1_loss(x, x_rec, reduction='none') + loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans + return loss + + @torch.jit.ignore + def no_weight_decay(self): + if hasattr(self.encoder, 'no_weight_decay'): + return {'encoder.' 
+ i for i in self.encoder.no_weight_decay()} + return {} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + if hasattr(self.encoder, 'no_weight_decay_keywords'): + return {'encoder.' + i for i in + self.encoder.no_weight_decay_keywords()} + return {} + + +def build_mim_model(config): + """Builds the masked-image-modeling model. + + Args: + config: config object + + Raises: + NotImplementedError: if the model is + not swinv2, then this will be thrown. + + Returns: + MiMModel: masked-image-modeling model + """ + model_type = config.MODEL.TYPE + if model_type == 'swinv2': + encoder = SwinTransformerV2ForSimMIM( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWINV2.PATCH_SIZE, + in_chans=config.MODEL.SWINV2.IN_CHANS, + num_classes=0, + embed_dim=config.MODEL.SWINV2.EMBED_DIM, + depths=config.MODEL.SWINV2.DEPTHS, + num_heads=config.MODEL.SWINV2.NUM_HEADS, + window_size=config.MODEL.SWINV2.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, + qkv_bias=config.MODEL.SWINV2.QKV_BIAS, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWINV2.APE, + patch_norm=config.MODEL.SWINV2.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT) + encoder_stride = 32 + in_chans = config.MODEL.SWINV2.IN_CHANS + patch_size = config.MODEL.SWINV2.PATCH_SIZE + else: + raise NotImplementedError(f"Unknown pre-train model: {model_type}") + + model = MiMModel(encoder=encoder, encoder_stride=encoder_stride, + in_chans=in_chans, patch_size=patch_size) + + return model diff --git a/pytorch-caney/pytorch_caney/models/simmim/__init__.py b/pytorch-caney/pytorch_caney/models/simmim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/models/simmim/simmim.py b/pytorch-caney/pytorch_caney/models/simmim/simmim.py new file mode 100644 index 0000000000000000000000000000000000000000..b13cfca7e06dc5da468012e983d07b8be37ae4e4 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/simmim/simmim.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_ + +from ..swinv2_model import SwinTransformerV2 + + +class SwinTransformerV2ForSimMIM(SwinTransformerV2): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + assert self.num_classes == 0 + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + trunc_normal_(self.mask_token, mean=0., std=.02) + + def forward(self, x, mask): + x = self.patch_embed(x) + + assert mask is not None + B, L, _ = x.shape + + mask_tokens = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) + x = x * (1. 
- w) + mask_tokens * w + + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + x = self.norm(x) + + x = x.transpose(1, 2) + B, C, L = x.shape + H = W = int(L ** 0.5) + x = x.reshape(B, C, H, W) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return super().no_weight_decay() | {'mask_token'} + + +class MiMModel(nn.Module): + def __init__(self, encoder, encoder_stride, in_chans, patch_size): + super().__init__() + self.encoder = encoder + self.encoder_stride = encoder_stride + self.in_chans = in_chans + self.patch_size = patch_size + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.num_features, + out_channels=self.encoder_stride ** 2 * self.in_chans, + kernel_size=1), + nn.PixelShuffle(self.encoder_stride), + ) + + # self.in_chans = self.encoder.in_chans + # self.patch_size = self.encoder.patch_size + + def forward(self, x, mask): + z = self.encoder(x, mask) + x_rec = self.decoder(z) + + mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave( + self.patch_size, 2).unsqueeze(1).contiguous() + loss_recon = F.l1_loss(x, x_rec, reduction='none') + loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans + return loss + + @torch.jit.ignore + def no_weight_decay(self): + if hasattr(self.encoder, 'no_weight_decay'): + return {'encoder.' + i for i in self.encoder.no_weight_decay()} + return {} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + if hasattr(self.encoder, 'no_weight_decay_keywords'): + return {'encoder.' + i for i in + self.encoder.no_weight_decay_keywords()} + return {} + + +def build_mim_model(config): + model_type = config.MODEL.TYPE + if model_type == 'swinv2': + encoder = SwinTransformerV2ForSimMIM( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWINV2.PATCH_SIZE, + in_chans=config.MODEL.SWINV2.IN_CHANS, + num_classes=0, + embed_dim=config.MODEL.SWINV2.EMBED_DIM, + depths=config.MODEL.SWINV2.DEPTHS, + num_heads=config.MODEL.SWINV2.NUM_HEADS, + window_size=config.MODEL.SWINV2.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, + qkv_bias=config.MODEL.SWINV2.QKV_BIAS, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWINV2.APE, + patch_norm=config.MODEL.SWINV2.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT) + encoder_stride = 32 + in_chans = config.MODEL.SWINV2.IN_CHANS + patch_size = config.MODEL.SWINV2.PATCH_SIZE + else: + raise NotImplementedError(f"Unknown pre-train model: {model_type}") + + model = MiMModel(encoder=encoder, encoder_stride=encoder_stride, + in_chans=in_chans, patch_size=patch_size) + + return model diff --git a/pytorch-caney/pytorch_caney/models/swinv2_model.py b/pytorch-caney/pytorch_caney/models/swinv2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3a95b2a864fdf6566ce9b7fee6ae258c47eb7af9 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/swinv2_model.py @@ -0,0 +1,579 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +from pytorch_caney.network.mlp import Mlp +from pytorch_caney.network.attention import WindowAttention + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, + W // window_size, window_size, C) + 
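    # Shape bookkeeping (illustrative): for H = W = 8 and window_size = 4 the
    # view above yields (B, 2, 4, 2, 4, C); the permute/reshape below regroups
    # it into (B * 4, 4, 4, C), i.e. four 4x4 windows per image.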
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous( + ).view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, + window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, + key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. + Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, + window_size=7, shift_size=0, mlp_ratio=4., + qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, + # we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + + assert 0 <= self.shift_size < self.window_size, \ + "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + # nW, window_size, window_size, 1 + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view( + -1, + self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, + float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + # nW*B, window_size, window_size, C + x_windows = window_partition(shifted_x, self.window_size) + # nW*B, window_size*window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # nW*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, + self.window_size, self.window_size, C) + shifted_x = window_reverse( + attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=( + self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}," \ + f"num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, " \ + f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
+ Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable + bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. + Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. + Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. + Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer + at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing + to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. 
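For intuition on the PatchMerging layer above: it regroups every 2x2 neighborhood of tokens into one token of 4*C features and linearly projects it down to 2*C, halving the grid in each spatial dimension. A standalone sketch with illustrative sizes:

import torch

B, H, W, C = 2, 16, 16, 96                 # arbitrary token grid
x = torch.randn(B, H, W, C)

x0 = x[:, 0::2, 0::2, :]                   # top-left of each 2x2 block
x1 = x[:, 1::2, 0::2, :]                   # bottom-left
x2 = x[:, 0::2, 1::2, :]                   # top-right
x3 = x[:, 1::2, 1::2, :]                   # bottom-right
merged = torch.cat([x0, x1, x2, x3], -1)   # B, H/2, W/2, 4*C
merged = merged.view(B, -1, 4 * C)         # B, (H/2)*(W/2), 4*C

reduction = torch.nn.Linear(4 * C, 2 * C, bias=False)
out = reduction(merged)
assert out.shape == (B, (H // 2) * (W // 2), 2 * C)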
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, + use_checkpoint=False, pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if ( + i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, " \ + f"input_resolution={self.input_resolution}," \ + f" depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. + Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // + patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, + kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W})" \ + f"doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * \ + (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerV2(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical + Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. + Default: 3 + num_classes (int): Number of classes for classification head. + Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. + Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch + embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. + Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. + Default: False + pretrained_window_sizes (tuple(int)): Pretrained window sizes of + each layer. 
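PatchEmbed above tokenizes an image with a single strided convolution: a 224x224 input with patch size 4 yields a 56x56 grid, i.e. 3136 tokens of embed_dim features. A short standalone check using the same default values:

import torch
import torch.nn as nn

img_size, patch_size, in_chans, embed_dim = 224, 4, 3, 96

proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(1, in_chans, img_size, img_size)

tokens = proj(x).flatten(2).transpose(1, 2)        # B, Ph*Pw, embed_dim
assert tokens.shape == (1, (img_size // patch_size) ** 2, embed_dim)
print(tokens.shape)                                # torch.Size([1, 3136, 96])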
+ """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, + num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4., + qkv_bias=True, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0.1, norm_layer=nn.LayerNorm, + ape=False, patch_norm=True, use_checkpoint=False, + pretrained_window_sizes=[0, 0, 0, 0], **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, + in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] + # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum( + depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pretrained_window_size=pretrained_window_sizes[i_layer]) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear( + self.num_features, num_classes) if \ + num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"cpb_mlp", "logit_scale", 'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + return x + + def extra_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + feature = [] + + for layer in self.layers: + x = layer(x) + bs, n, f = x.shape + h = int(n**0.5) + + feature.append( + x.view(-1, h, h, f).permute(0, 3, 1, 2).contiguous()) + return feature + + def get_unet_feature(self, x): + x = self.patch_embed(x) + if self.ape: + 
x = x + self.absolute_pos_embed + x = self.pos_drop(x) + bs, n, f = x.shape + h = int(n**0.5) + feature = [x.view(-1, h, h, f).permute(0, 3, 1, 2).contiguous()] + + for layer in self.layers: + x = layer(x) + bs, n, f = x.shape + h = int(n**0.5) + + feature.append( + x.view(-1, h, h, f).permute(0, 3, 1, 2).contiguous()) + return feature + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * \ + self.patches_resolution[0] * \ + self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops diff --git a/pytorch-caney/pytorch_caney/models/unet_model.py b/pytorch-caney/pytorch_caney/models/unet_model.py new file mode 100755 index 0000000000000000000000000000000000000000..b8e0779367af05e9b0182c72530ef508fc609ca4 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/unet_model.py @@ -0,0 +1,187 @@ +from pl_bolts.models.vision.unet import UNet +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities.cli import MODEL_REGISTRY + +import torch +from torch.nn import functional as F +from torchmetrics import MetricCollection, Accuracy, IoU + + +# ------------------------------------------------------------------------------- +# class UNet +# This class performs training and classification of satellite imagery using a +# UNet CNN. +# ------------------------------------------------------------------------------- +@MODEL_REGISTRY +class UNetSegmentation(LightningModule): + + # --------------------------------------------------------------------------- + # __init__ + # --------------------------------------------------------------------------- + def __init__( + self, + input_channels: int = 4, + num_classes: int = 19, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False, + ): + super().__init__() + + self.input_channels = input_channels + self.num_classes = num_classes + self.num_layers = num_layers + self.features_start = features_start + self.bilinear = bilinear + + self.net = UNet( + input_channels=self.input_channels, + num_classes=num_classes, + num_layers=self.num_layers, + features_start=self.features_start, + bilinear=self.bilinear, + ) + + metrics = MetricCollection( + [ + Accuracy(), IoU(num_classes=self.num_classes) + ] + ) + self.train_metrics = metrics.clone(prefix='train_') + self.val_metrics = metrics.clone(prefix='val_') + + # --------------------------------------------------------------------------- + # model methods + # --------------------------------------------------------------------------- + def forward(self, x): + return self.net(x) + + def training_step(self, batch, batch_nb): + img, mask = batch + img, mask = img.float(), mask.long() + + # Forward step, calculate logits and loss + logits = self(img) + # loss_val = F.cross_entropy(logits, mask) + + # Get target tensor from logits for metrics, calculate metrics + probs = torch.nn.functional.softmax(logits, dim=1) + probs = torch.argmax(probs, dim=1) + + # metrics_train = self.train_metrics(probs, mask) + # log_dict = {"train_loss": loss_val.detach()} + # return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict} + # return { + # "loss": loss_val, "train_acc": metrics_train['train_Accuracy'], + # "train_iou": metrics_train['train_IoU'] + # } + + tensorboard_logs = self.train_metrics(probs, mask) + tensorboard_logs['loss'] = 
F.cross_entropy(logits, mask) + # tensorboard_logs['lr'] = self._get_current_lr() + + self.log( + 'acc', tensorboard_logs['train_Accuracy'], + sync_dist=True, prog_bar=True + ) + self.log( + 'iou', tensorboard_logs['train_IoU'], + sync_dist=True, prog_bar=True + ) + return tensorboard_logs + + def training_epoch_end(self, outputs): + pass + + # Get average metrics from multi-GPU batch sources + # loss_val = torch.stack([x["loss"] for x in outputs]).mean() + # acc_train = torch.stack([x["train_acc"] for x in outputs]).mean() + # iou_train = torch.stack([x["train_iou"] for x in outputs]).mean() + + # tensorboard_logs = self.train_metrics(probs, mask) + # tensorboard_logs['loss'] = F.cross_entropy(logits, mask) + # tensorboard_logs['lr'] = self._get_current_lr() + + # self.log( + # 'acc', tensorboard_logs['train_Accuracy'], + # sync_dist=True, prog_bar=True + # ) + # self.log( + # 'iou', tensorboard_logs['train_IoU'], + # sync_dist=True, prog_bar=True + # ) + # # Send output to logger + # self.log( + # "loss", loss_val, on_epoch=True, prog_bar=True, logger=True) + # self.log( + # "train_acc", acc_train, + # on_epoch=True, prog_bar=True, logger=True) + # self.log( + # "train_iou", iou_train, + # on_epoch=True, prog_bar=True, logger=True) + # return tensorboard_logs + + def validation_step(self, batch, batch_idx): + + # Get data, change type for validation + img, mask = batch + img, mask = img.float(), mask.long() + + # Forward step, calculate logits and loss + logits = self(img) + # loss_val = F.cross_entropy(logits, mask) + + # Get target tensor from logits for metrics, calculate metrics + probs = torch.nn.functional.softmax(logits, dim=1) + probs = torch.argmax(probs, dim=1) + # metrics_val = self.val_metrics(probs, mask) + + # return { + # "val_loss": loss_val, "val_acc": metrics_val['val_Accuracy'], + # "val_iou": metrics_val['val_IoU'] + # } + tensorboard_logs = self.val_metrics(probs, mask) + tensorboard_logs['val_loss'] = F.cross_entropy(logits, mask) + + self.log( + 'val_loss', tensorboard_logs['val_loss'], + sync_dist=True, prog_bar=True + ) + self.log( + 'val_acc', tensorboard_logs['val_Accuracy'], + sync_dist=True, prog_bar=True + ) + self.log( + 'val_iou', tensorboard_logs['val_IoU'], + sync_dist=True, prog_bar=True + ) + return tensorboard_logs + + # def validation_epoch_end(self, outputs): + + # # Get average metrics from multi-GPU batch sources + # loss_val = torch.stack([x["val_loss"] for x in outputs]).mean() + # acc_val = torch.stack([x["val_acc"] for x in outputs]).mean() + # iou_val = torch.stack([x["val_iou"] for x in outputs]).mean() + + # # Send output to logger + # self.log( + # "val_loss", torch.mean(self.all_gather(loss_val)), + # on_epoch=True, prog_bar=True, logger=True) + # self.log( + # "val_acc", torch.mean(self.all_gather(acc_val)), + # on_epoch=True, prog_bar=True, logger=True) + # self.log( + # "val_iou", torch.mean(self.all_gather(iou_val)), + # on_epoch=True, prog_bar=True, logger=True) + + # def configure_optimizers(self): + # opt = torch.optim.Adam(self.net.parameters(), lr=self.lr) + # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10) + # return [opt], [sch] + + def test_step(self, batch, batch_idx, dataloader_idx=0): + return self(batch) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + return self(batch) diff --git a/pytorch-caney/pytorch_caney/models/unet_swin_model.py b/pytorch-caney/pytorch_caney/models/unet_swin_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..8fa2982067aabf668bfc7f9973ac4fca528d6b67 --- /dev/null +++ b/pytorch-caney/pytorch_caney/models/unet_swin_model.py @@ -0,0 +1,44 @@ +from .decoders.unet_decoder import UnetDecoder +from .decoders.unet_decoder import SegmentationHead + +import torch.nn as nn + +from typing import Tuple + + +class unet_swin(nn.Module): + """ + Pytorch encoder-decoder model which pairs + an encoder (swin) with the attention unet + decoder. + """ + + FEATURE_CHANNELS: Tuple[int] = (3, 256, 512, 1024, 1024) + DECODE_CHANNELS: Tuple[int] = (512, 256, 128, 64) + IN_CHANNELS: int = 64 + N_BLOCKS: int = 4 + KERNEL_SIZE: int = 3 + UPSAMPLING: int = 4 + + def __init__(self, encoder, num_classes=9): + super().__init__() + + self.encoder = encoder + + self.decoder = UnetDecoder( + encoder_channels=self.FEATURE_CHANNELS, + n_blocks=self.N_BLOCKS, + decoder_channels=self.DECODE_CHANNELS, + attention_type=None) + self.segmentation_head = SegmentationHead( + in_channels=self.IN_CHANNELS, + out_channels=num_classes, + kernel_size=self.KERNEL_SIZE, + upsampling=self.UPSAMPLING) + + def forward(self, x): + encoder_featrue = self.encoder.get_unet_feature(x) + decoder_output = self.decoder(*encoder_featrue) + masks = self.segmentation_head(decoder_output) + + return masks diff --git a/pytorch-caney/pytorch_caney/network/__init__.py b/pytorch-caney/pytorch_caney/network/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/network/__pycache__/__init__.cpython-310.pyc b/pytorch-caney/pytorch_caney/network/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a597934bd85e8e895d805abd89d2c6ad190a5c1 Binary files /dev/null and b/pytorch-caney/pytorch_caney/network/__pycache__/__init__.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/network/__pycache__/attention.cpython-310.pyc b/pytorch-caney/pytorch_caney/network/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef47bbb9ce801f75d0225fd0697b67b7932688c Binary files /dev/null and b/pytorch-caney/pytorch_caney/network/__pycache__/attention.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/network/__pycache__/mlp.cpython-310.pyc b/pytorch-caney/pytorch_caney/network/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14d59eb2fd95e186b0e5e1761c0f5239a0ea73a8 Binary files /dev/null and b/pytorch-caney/pytorch_caney/network/__pycache__/mlp.cpython-310.pyc differ diff --git a/pytorch-caney/pytorch_caney/network/attention.py b/pytorch-caney/pytorch_caney/network/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2941d756b31496db69840e5b24a0b89b712a000c --- /dev/null +++ b/pytorch-caney/pytorch_caney/network/attention.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class WindowAttention(nn.Module): + """ + Window based multi-head self attention (W-MSA) module with + relative position bias. It supports both of shifted and + non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, + key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. 
+ Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the + window in pre-training. + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + + self.dim = dim + + self.window_size = window_size # Wh, Ww + + self.pretrained_window_size = pretrained_window_size + + self.num_heads = num_heads + + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), + self.window_size[0], + dtype=torch.float32) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), + self.window_size[1], + dtype=torch.float32) + + # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_table = torch.stack( + torch.meshgrid( + [relative_coords_h, + relative_coords_w])).permute(1, + 2, + 0).contiguous().unsqueeze(0) + + if pretrained_window_size[0] > 0: + + relative_coords_table[:, :, :, + 0] /= (pretrained_window_size[0] - 1) + + relative_coords_table[:, :, :, + 1] /= (pretrained_window_size[1] - 1) + + else: + + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + + relative_coords_table *= 8 # normalize to -8, 8 + + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside + # the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + + relative_coords = coords_flatten[:, :, None] - \ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + + relative_coords[:, :, 0] += self.window_size[0] - \ + 1 # shift to start from 0 + + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + self.register_buffer("relative_position_index", + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + + if qkv_bias: + + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + + else: + + self.q_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) + or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like( + self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # make torchscript happy 
(cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + # cosine attention + attn = (F.normalize(q, dim=-1) @ + F.normalize(k, dim=-1).transpose(-2, -1)) + # logit_scale = torch.clamp( + # self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + logit_scale = torch.clamp(self.logit_scale, max=torch.log( + torch.tensor(1. / 0.01))).exp() # .to(self.logit_scale.get_device()) + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp( + self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = \ + relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + # Wh*Ww,Wh*Ww,nH + + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, ' \ + f'num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops diff --git a/pytorch-caney/pytorch_caney/network/mlp.py b/pytorch-caney/pytorch_caney/network/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..d15480836f6c416b55aa12148bbe3f83add434ec --- /dev/null +++ b/pytorch-caney/pytorch_caney/network/mlp.py @@ -0,0 +1,21 @@ +import torch.nn as nn + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py b/pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3b05b18091e105472547fa8fa3ae5385f79296 --- /dev/null +++ b/pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py @@ -0,0 +1,454 @@ +from pytorch_caney.models.build import build_model + +from pytorch_caney.data.datamodules.finetune_datamodule \ + import build_finetune_dataloaders + +from pytorch_caney.training.mim_utils \ + import build_optimizer, save_checkpoint, reduce_tensor + +from pytorch_caney.config import get_config +from pytorch_caney.loss.build import build_loss +from pytorch_caney.lr_scheduler import 
build_scheduler, setup_scaled_lr +from pytorch_caney.ptc_logging import create_logger +from pytorch_caney.training.mim_utils import get_grad_norm + +import argparse +import datetime +import joblib +import numpy as np +import os +import time + +import torch +import torch.cuda.amp as amp +import torch.backends.cudnn as cudnn +import torch.distributed as dist + +from timm.utils import AverageMeter + + +def parse_args(): + """ + Parse command-line arguments + """ + + parser = argparse.ArgumentParser( + 'pytorch-caney finetuning', + add_help=False) + + parser.add_argument( + '--cfg', + type=str, + required=True, + metavar="FILE", + help='path to config file') + + parser.add_argument( + "--data-paths", + nargs='+', + required=True, + help="paths where dataset is stored") + + parser.add_argument( + '--dataset', + type=str, + required=True, + help='Dataset to use') + + parser.add_argument( + '--pretrained', + type=str, + help='path to pre-trained model') + + parser.add_argument( + '--batch-size', + type=int, + help="batch size for single GPU") + + parser.add_argument( + '--resume', + help='resume from checkpoint') + + parser.add_argument( + '--accumulation-steps', + type=int, + help="gradient accumulation steps") + + parser.add_argument( + '--use-checkpoint', + action='store_true', + help="whether to use gradient checkpointing to save memory") + + parser.add_argument( + '--enable-amp', + action='store_true') + + parser.add_argument( + '--disable-amp', + action='store_false', + dest='enable_amp') + + parser.set_defaults(enable_amp=True) + + parser.add_argument( + '--output', + default='output', + type=str, + metavar='PATH', + help='root of output folder, the full path is ' + + '// (default: output)') + + parser.add_argument( + '--tag', + help='tag of experiment') + + args = parser.parse_args() + + config = get_config(args) + + return args, config + + +def train(config, + dataloader_train, + dataloader_val, + model, + model_wo_ddp, + optimizer, + lr_scheduler, + scaler, + criterion): + """ + Start fine-tuning a specific model and dataset. + + Args: + config: config object + dataloader_train: training pytorch dataloader + dataloader_val: validation pytorch dataloader + model: model to pre-train + model_wo_ddp: model to pre-train that is not the DDP version + optimizer: pytorch optimizer + lr_scheduler: learning-rate scheduler + scaler: loss scaler + criterion: loss function to use for fine-tuning + """ + + logger.info("Start fine-tuning") + + start_time = time.time() + + for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): + + dataloader_train.sampler.set_epoch(epoch) + + execute_one_epoch(config, model, dataloader_train, + optimizer, criterion, epoch, lr_scheduler, scaler) + + loss = validate(config, model, dataloader_val, criterion) + + logger.info(f'Model validation loss: {loss:.3f}%') + + if dist.get_rank() == 0 and \ + (epoch % config.SAVE_FREQ == 0 or + epoch == (config.TRAIN.EPOCHS - 1)): + + save_checkpoint(config, epoch, model_wo_ddp, 0., + optimizer, lr_scheduler, scaler, logger) + + total_time = time.time() - start_time + + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + + logger.info('Training time {}'.format(total_time_str)) + + +def execute_one_epoch(config, + model, + dataloader, + optimizer, + criterion, + epoch, + lr_scheduler, + scaler): + """ + Execute training iterations on a single epoch. 
+ + Args: + config: config object + model: model to pre-train + dataloader: dataloader to use + optimizer: pytorch optimizer + epoch: int epoch number + lr_scheduler: learning-rate scheduler + scaler: loss scaler + """ + model.train() + + optimizer.zero_grad() + + num_steps = len(dataloader) + + # Set up logging meters + batch_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + norm_meter = AverageMeter() + loss_scale_meter = AverageMeter() + + start = time.time() + end = time.time() + for idx, (samples, targets) in enumerate(dataloader): + + data_time.update(time.time() - start) + + samples = samples.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + + with amp.autocast(enabled=config.ENABLE_AMP): + logits = model(samples) + + if config.TRAIN.ACCUMULATION_STEPS > 1: + loss = criterion(logits, targets) + loss = loss / config.TRAIN.ACCUMULATION_STEPS + scaler.scale(loss).backward() + if config.TRAIN.CLIP_GRAD: + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(model.parameters()) + if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: + scaler.step(optimizer) + optimizer.zero_grad() + scaler.update() + lr_scheduler.step_update(epoch * num_steps + idx) + else: + loss = criterion(logits, targets) + optimizer.zero_grad() + scaler.scale(loss).backward() + if config.TRAIN.CLIP_GRAD: + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(model.parameters()) + scaler.step(optimizer) + scaler.update() + lr_scheduler.step_update(epoch * num_steps + idx) + + torch.cuda.synchronize() + + loss_meter.update(loss.item(), targets.size(0)) + norm_meter.update(grad_norm) + loss_scale_meter.update(scaler.get_scale()) + batch_time.update(time.time() - end) + end = time.time() + + if idx % config.PRINT_FREQ == 0: + lr = optimizer.param_groups[0]['lr'] + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + etas = batch_time.avg * (num_steps - idx) + logger.info( + f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' + f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' + f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' + f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t' + f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' + f'loss_scale {loss_scale_meter.val:.4f}' + + f' ({loss_scale_meter.avg:.4f})\t' + f'mem {memory_used:.0f}MB') + + epoch_time = time.time() - start + logger.info( + f"EPOCH {epoch} training takes " + + f"{datetime.timedelta(seconds=int(epoch_time))}") + + +@torch.no_grad() +def validate(config, model, dataloader, criterion): + """Validation function which given a model and validation loader + performs a validation run and returns the average loss according + to the criterion. 
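The accumulation branch in execute_one_epoch above follows the usual AMP recipe: divide the loss by the number of accumulation steps, backpropagate through the GradScaler every batch, and only unscale, clip, and step the optimizer once per accumulation window. A minimal standalone sketch of that pattern with a toy model and synthetic data (assumes a CUDA device, as the pipeline does; sizes and hyperparameters are illustrative):

import torch
import torch.cuda.amp as amp

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = amp.GradScaler()
accumulation_steps, clip_grad = 4, 1.0

optimizer.zero_grad()
for idx in range(16):                                   # stand-in for a dataloader
    samples = torch.randn(8, 16, device='cuda')
    targets = torch.randint(0, 4, (8,), device='cuda')

    with amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(samples), targets)

    # spread one effective batch over several micro-batches
    scaler.scale(loss / accumulation_steps).backward()

    if (idx + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                      # so clipping sees real grads
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()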
+ + Args: + config: config object + model: pytorch model to validate + dataloader: pytorch validation loader + criterion: pytorch-friendly loss function + + Returns: + loss_meter.avg: average of the loss throught the validation + iterations + """ + + model.eval() + + batch_time = AverageMeter() + + loss_meter = AverageMeter() + + end = time.time() + + for idx, (images, target) in enumerate(dataloader): + + images = images.cuda(non_blocking=True) + + target = target.cuda(non_blocking=True) + + # compute output + output = model(images) + + # measure accuracy and record loss + loss = criterion(output, target.long()) + + loss = reduce_tensor(loss) + + loss_meter.update(loss.item(), target.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + + end = time.time() + + if idx % config.PRINT_FREQ == 0: + + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + + logger.info( + f'Test: [{idx}/{len(dataloader)}]\t' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'Mem {memory_used:.0f}MB') + + return loss_meter.avg + + +def main(config): + """ + Performs the main function of building model, loader, etc. and starts + training. + """ + + dataloader_train, dataloader_val = build_finetune_dataloaders( + config, logger) + + model = build_finetune_model(config, logger) + + optimizer = build_optimizer(config, + model, + is_pretrain=False, + logger=logger) + + model, model_wo_ddp = make_ddp(model) + + n_iter_per_epoch = len(dataloader_train) + + lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch) + + scaler = amp.GradScaler() + + criterion = build_loss(config) + + train(config, + dataloader_train, + dataloader_val, + model, + model_wo_ddp, + optimizer, + lr_scheduler, + scaler, + criterion) + + +def build_finetune_model(config, logger): + + logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") + + model = build_model(config, + pretrain=False, + pretrain_method='mim', + logger=logger) + + model.cuda() + + logger.info(str(model)) + + return model + + +def make_ddp(model): + + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[int(os.environ["RANK"])], + broadcast_buffers=False, + find_unused_parameters=True) + + model_without_ddp = model.module + + return model, model_without_ddp + + +def setup_rank_worldsize(): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ['WORLD_SIZE']) + print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") + else: + rank = -1 + world_size = -1 + return rank, world_size + + +def setup_distributed_processing(rank, world_size): + torch.cuda.set_device(int(os.environ["RANK"])) + torch.distributed.init_process_group( + backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.barrier() + + +def setup_seeding(config): + seed = config.SEED + dist.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + +if __name__ == '__main__': + _, config = parse_args() + + rank, world_size = setup_rank_worldsize() + + setup_distributed_processing(rank, world_size) + + setup_seeding(config) + + cudnn.benchmark = True + + linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \ + setup_scaled_lr(config) + + config.defrost() + config.TRAIN.BASE_LR = linear_scaled_lr + config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr + config.TRAIN.MIN_LR = linear_scaled_min_lr + config.freeze() + + os.makedirs(config.OUTPUT, exist_ok=True) 
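reduce_tensor is imported from pytorch_caney.training.mim_utils and its body is not part of this diff; given how validate above averages the returned value, a typical implementation would sum the scalar across DDP ranks and divide by the world size. A hedged sketch of such a helper (it assumes the process group has already been initialised, as setup_distributed_processing does):

import torch
import torch.distributed as dist

def reduce_tensor(tensor: torch.Tensor) -> torch.Tensor:
    # Average a metric tensor across all ranks participating in DDP.
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced /= dist.get_world_size()
    return reduced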
+ logger = create_logger(output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.MODEL.NAME}") + + if dist.get_rank() == 0: + path = os.path.join(config.OUTPUT, "config.json") + with open(path, "w") as f: + f.write(config.dump()) + logger.info(f"Full config saved to {path}") + logger.info(config.dump()) + config_file_name = f'{config.TAG}.config.sav' + config_file_path = os.path.join(config.OUTPUT, config_file_name) + joblib.dump(config, config_file_path) + + main(config) diff --git a/pytorch-caney/pytorch_caney/pipelines/modis_segmentation.py b/pytorch-caney/pytorch_caney/pipelines/modis_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..2c58e9c1e7de786394ec5acdedd90273cc49a1cf --- /dev/null +++ b/pytorch-caney/pytorch_caney/pipelines/modis_segmentation.py @@ -0,0 +1,364 @@ +from argparse import ArgumentParser, Namespace +import multiprocessing + +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +import torchvision.transforms as transforms + +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.loggers import CSVLogger + +from pytorch_caney.datasets.modis_dataset import MODISDataset +from pytorch_caney.utils import check_gpus_available + + +class UNet(nn.Module): + """ + Architecture based on U-Net: Convolutional Networks for + Biomedical Image Segmentation. + Link - https://arxiv.org/abs/1505.04597 + >>> UNet(num_classes=2, num_layers=3) \ + # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + UNet( + (layers): ModuleList( + (0): DoubleConv(...) + (1): Down(...) + (2): Down(...) + (3): Up(...) + (4): Up(...) + (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + ) + """ + + def __init__( + self, + num_channels: int = 7, + num_classes: int = 19, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False + ): + + super().__init__() + self.num_layers = num_layers + + layers = [DoubleConv(num_channels, features_start)] + + feats = features_start + for _ in range(num_layers - 1): + layers.append(Down(feats, feats * 2)) + feats *= 2 + + for _ in range(num_layers - 1): + layers.append(Up(feats, feats // 2, bilinear)) + feats //= 2 + + layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) + + self.layers = nn.ModuleList(layers) + + def forward(self, x): + xi = [self.layers[0](x)] + # Down path + for layer in self.layers[1: self.num_layers]: + xi.append(layer(xi[-1])) + # Up path + for i, layer in enumerate(self.layers[self.num_layers: -1]): + xi[-1] = layer(xi[-1], xi[-2 - i]) + return self.layers[-1](xi[-1]) + + +class DoubleConv(nn.Module): + """Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2. + >>> DoubleConv(4, 4) \ + # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DoubleConv( + (net): Sequential(...) + ) + """ + + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.net(x) + + +class Down(nn.Module): + """Combination of MaxPool2d and DoubleConv in series. 
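A quick smoke test of the UNet defined in this file (it relies on the DoubleConv, Down, and Up blocks declared alongside it): with the default seven MODIS bands and a spatial size divisible by 2**(num_layers - 1), the logits come back at the input resolution with one channel per class. The batch and chip sizes below are illustrative:

import torch

net = UNet(num_channels=7, num_classes=18, num_layers=5)

x = torch.randn(2, 7, 128, 128)            # two synthetic 7-band chips
logits = net(x)
assert logits.shape == (2, 18, 128, 128)   # per-pixel class logits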
+ >>> Down(4, 8) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Down( + (net): Sequential( + (0): MaxPool2d( + kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (1): DoubleConv( + (net): Sequential(...) + ) + ) + ) + """ + + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.net = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch)) + + def forward(self, x): + return self.net(x) + + +class Up(nn.Module): + """Upsampling (by either bilinear interpolation or transpose convolutions) + followed by concatenation of feature + map from contracting path, followed by double 3x3 convolution. + >>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Up( + (upsample): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2)) + (conv): DoubleConv( + (net): Sequential(...) + ) + ) + """ + + def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): + super().__init__() + self.upsample = None + if bilinear: + self.upsample = nn.Sequential( + nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d( + in_ch, in_ch // 2, kernel_size=1), + ) + else: + self.upsample = nn.ConvTranspose2d( + in_ch, in_ch // 2, kernel_size=2, stride=2) + + self.conv = DoubleConv(in_ch, out_ch) + + def forward(self, x1, x2): + x1 = self.upsample(x1) + + # Pad x1 to the size of x2 + diff_h = x2.shape[2] - x1.shape[2] + diff_w = x2.shape[3] - x1.shape[3] + + x1 = F.pad( + x1, + [ + diff_w // 2, diff_w - diff_w // 2, + diff_h // 2, diff_h - diff_h // 2 + ]) + + # Concatenate along the channels axis + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class SegmentationModel(LightningModule): + + def __init__( + self, + data_path: list = [], + n_classes: int = 18, + batch_size: int = 256, + lr: float = 3e-4, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.data_paths = data_path + self.n_classes = n_classes + self.batch_size = batch_size + self.learning_rate = lr + self.num_layers = num_layers + self.features_start = features_start + self.bilinear = bilinear + self.validation_step_outputs = [] + + self.net = UNet( + num_classes=self.n_classes, + num_layers=self.num_layers, + features_start=self.features_start, + bilinear=self.bilinear + ) + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.0173, 0.0332, 0.0088, + 0.0136, 0.0381, 0.0348, 0.0249], + std=[0.0150, 0.0127, 0.0124, + 0.0128, 0.0120, 0.0159, 0.0164] + ), + ] + ) + print('> Init datasets') + self.trainset = MODISDataset( + self.data_paths, split="train", transform=self.transform) + self.validset = MODISDataset( + self.data_paths, split="valid", transform=self.transform) + print('Done init datasets') + + def forward(self, x): + return self.net(x) + + def training_step(self, batch, batch_nb): + img, mask = batch + img = img.float() + mask = mask.long() + out = self(img) + loss = F.cross_entropy(out, mask, ignore_index=250) + log_dict = {"train_loss": loss} + self.log_dict(log_dict) + return {"loss": loss, "log": log_dict, "progress_bar": log_dict} + + def validation_step(self, batch, batch_idx): + img, mask = batch + img = img.float() + mask = mask.long() + out = self(img) + loss_val = F.cross_entropy(out, mask, ignore_index=250) + self.validation_step_outputs.append(loss_val) + return {"val_loss": loss_val} + + def on_validation_epoch_end(self): + loss_val = torch.stack(self.validation_step_outputs).mean() + log_dict = {"val_loss": 
loss_val} + self.log("val_loss", loss_val, sync_dist=True) + self.validation_step_outputs.clear() + return { + "log": log_dict, + "val_loss": log_dict["val_loss"], + "progress_bar": log_dict + } + + def configure_optimizers(self): + opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate) + # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10) + return [opt] # , [sch] + + def train_dataloader(self): + return DataLoader( + self.trainset, + batch_size=self.batch_size, + num_workers=multiprocessing.cpu_count(), + shuffle=True + ) + + def val_dataloader(self): + return DataLoader( + self.validset, + batch_size=self.batch_size, + num_workers=multiprocessing.cpu_count(), + shuffle=False + ) + + +def main(hparams: Namespace): + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ + ngpus = int(hparams.ngpus) + # PT ligtning does not expect this, del after use + del hparams.ngpus + + model = SegmentationModel(**vars(hparams)) + + # ------------------------ + # 2 SET LOGGER + # ------------------------ + # logger = True + # if hparams.log_wandb: + # logger = WandbLogger() + # # optional: log model topology + # logger.watch(model.net) + + train_callbacks = [ + # TQDMProgressBar(refresh_rate=20), + ModelCheckpoint(dirpath='models/', + monitor='val_loss', + save_top_k=5, + filename='{epoch}-{val_loss:.2f}.ckpt'), + EarlyStopping("val_loss", patience=10, mode='min'), + ] + + # See number of devices + check_gpus_available(ngpus) + + # ------------------------ + # 3 INIT TRAINER + # ------------------------ + # trainer = Trainer( + # ------------------------ + trainer = Trainer( + accelerator="gpu", + devices=ngpus, + strategy="ddp", + min_epochs=1, + max_epochs=500, + callbacks=train_callbacks, + logger=CSVLogger(save_dir="logs/"), + # precision=16 # makes loss nan, need to fix that + ) + + # ------------------------ + # 5 START TRAINING + # ------------------------ + trainer.fit(model) + trainer.save_checkpoint("best_model.ckpt") + + # ------------------------ + # 6 START TEST + # ------------------------ + # test_set = MODISDataset( + # self.data_path, split=None, transform=self.transform) + # test_dataloader = DataLoader(...) 
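The Compose transform in SegmentationModel above expects a channels-last array: ToTensor moves it to CHW and Normalize then applies the per-band MODIS statistics. A small standalone sketch pushing one synthetic chip through the same preprocessing (the mean and std values are copied from the module above; the chip itself is random):

import numpy as np
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),                 # HWC float32 array -> CHW tensor
    transforms.Normalize(
        mean=[0.0173, 0.0332, 0.0088, 0.0136, 0.0381, 0.0348, 0.0249],
        std=[0.0150, 0.0127, 0.0124, 0.0128, 0.0120, 0.0159, 0.0164]),
])

chip = np.random.rand(128, 128, 7).astype(np.float32)   # synthetic 7-band chip
tensor = transform(chip)
print(tensor.shape)                                      # torch.Size([7, 128, 128])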
+ # trainer.test(ckpt_path="best", dataloaders=) + + +if __name__ == "__main__": + cli_lightning_logo() + + parser = ArgumentParser() + parser.add_argument( + "--data_path", nargs='+', required=True, + help="path where dataset is stored") + parser.add_argument('--ngpus', type=int, + default=torch.cuda.device_count(), + help='number of gpus to use') + parser.add_argument( + "--n-classes", type=int, default=18, help="number of classes") + parser.add_argument( + "--batch_size", type=int, default=256, help="size of the batches") + parser.add_argument( + "--lr", type=float, default=3e-4, help="adam: learning rate") + parser.add_argument( + "--num_layers", type=int, default=5, help="number of layers on u-net") + parser.add_argument( + "--features_start", type=float, default=64, + help="number of features in first layer") + parser.add_argument( + "--bilinear", action="store_true", default=False, + help="whether to use bilinear interpolation or transposed") + # parser.add_argument( + # "--log-wandb", action="store_true", default=True, + # help="whether to use wandb as the logger") + hparams = parser.parse_args() + + main(hparams) diff --git a/pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py b/pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcc7953ae57cbb53202876f5ea7838a09ed0baf --- /dev/null +++ b/pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py @@ -0,0 +1,371 @@ +from pytorch_caney.data.datamodules.mim_datamodule \ + import build_mim_dataloader + +from pytorch_caney.models.mim.mim \ + import build_mim_model + +from pytorch_caney.training.mim_utils \ + import build_optimizer, save_checkpoint + +from pytorch_caney.training.mim_utils import get_grad_norm +from pytorch_caney.lr_scheduler import build_scheduler, setup_scaled_lr +from pytorch_caney.ptc_logging import create_logger +from pytorch_caney.config import get_config + +import argparse +import datetime +import joblib +import numpy as np +import os +import time + +import torch +import torch.cuda.amp as amp +import torch.backends.cudnn as cudnn +import torch.distributed as dist + +from timm.utils import AverageMeter + + +def parse_args(): + """ + Parse command-line arguments + """ + parser = argparse.ArgumentParser( + 'pytorch-caney implementation of MiM pre-training script', + add_help=False) + + parser.add_argument( + '--cfg', + type=str, + required=True, + metavar="FILE", + help='path to config file') + + parser.add_argument( + "--data-paths", + nargs='+', + required=True, + help="paths where dataset is stored") + + parser.add_argument( + '--dataset', + type=str, + required=True, + help='Dataset to use') + + parser.add_argument( + '--batch-size', + type=int, + help="batch size for single GPU") + + parser.add_argument( + '--resume', + help='resume from checkpoint') + + parser.add_argument( + '--accumulation-steps', + type=int, + help="gradient accumulation steps") + + parser.add_argument( + '--use-checkpoint', + action='store_true', + help="whether to use gradient checkpointing to save memory") + + parser.add_argument( + '--enable-amp', + action='store_true') + + parser.add_argument( + '--disable-amp', + action='store_false', + dest='enable_amp') + + parser.set_defaults(enable_amp=True) + + parser.add_argument( + '--output', + default='output', + type=str, + metavar='PATH', + help='root of output folder, the full path is ' + + '// (default: output)') + + parser.add_argument( + '--tag', + help='tag of experiment') + + args = parser.parse_args() + + 
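Both parse_args functions in this diff (fine-tuning and MiM pre-training) expose the paired --enable-amp/--disable-amp toggle: store_true and store_false write to the same enable_amp destination, and set_defaults turns AMP on unless it is explicitly disabled. A tiny standalone illustration of how those three calls interact:

import argparse

parser = argparse.ArgumentParser('amp flag demo')
parser.add_argument('--enable-amp', action='store_true')
parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')
parser.set_defaults(enable_amp=True)

print(parser.parse_args([]).enable_amp)                  # True (default)
print(parser.parse_args(['--disable-amp']).enable_amp)   # False
print(parser.parse_args(['--enable-amp']).enable_amp)    # True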
config = get_config(args) + + return args, config + + +def train(config, + dataloader, + model, + model_wo_ddp, + optimizer, + lr_scheduler, + scaler): + """ + Start pre-training a specific model and dataset. + + Args: + config: config object + dataloader: dataloader to use + model: model to pre-train + model_wo_ddp: model to pre-train that is not the DDP version + optimizer: pytorch optimizer + lr_scheduler: learning-rate scheduler + scaler: loss scaler + """ + + logger.info("Start training") + + start_time = time.time() + + for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): + + dataloader.sampler.set_epoch(epoch) + + execute_one_epoch(config, model, dataloader, + optimizer, epoch, lr_scheduler, scaler) + + if dist.get_rank() == 0 and \ + (epoch % config.SAVE_FREQ == 0 or + epoch == (config.TRAIN.EPOCHS - 1)): + + save_checkpoint(config, epoch, model_wo_ddp, 0., + optimizer, lr_scheduler, scaler, logger) + + total_time = time.time() - start_time + + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + + logger.info('Training time {}'.format(total_time_str)) + + +def execute_one_epoch(config, + model, + dataloader, + optimizer, + epoch, + lr_scheduler, + scaler): + """ + Execute training iterations on a single epoch. + + Args: + config: config object + model: model to pre-train + dataloader: dataloader to use + optimizer: pytorch optimizer + epoch: int epoch number + lr_scheduler: learning-rate scheduler + scaler: loss scaler + """ + + model.train() + + optimizer.zero_grad() + + num_steps = len(dataloader) + + # Set up logging meters + batch_time = AverageMeter() + data_time = AverageMeter() + loss_meter = AverageMeter() + norm_meter = AverageMeter() + loss_scale_meter = AverageMeter() + + start = time.time() + end = time.time() + for idx, (img, mask, _) in enumerate(dataloader): + + data_time.update(time.time() - start) + + img = img.cuda(non_blocking=True) + mask = mask.cuda(non_blocking=True) + + with amp.autocast(enabled=config.ENABLE_AMP): + loss = model(img, mask) + + if config.TRAIN.ACCUMULATION_STEPS > 1: + loss = loss / config.TRAIN.ACCUMULATION_STEPS + scaler.scale(loss).backward() + loss.backward() + if config.TRAIN.CLIP_GRAD: + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(model.parameters()) + if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: + scaler.step(optimizer) + optimizer.zero_grad() + scaler.update() + lr_scheduler.step_update(epoch * num_steps + idx) + else: + optimizer.zero_grad() + scaler.scale(loss).backward() + if config.TRAIN.CLIP_GRAD: + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(model.parameters()) + scaler.step(optimizer) + scaler.update() + lr_scheduler.step_update(epoch * num_steps + idx) + + torch.cuda.synchronize() + + loss_meter.update(loss.item(), img.size(0)) + norm_meter.update(grad_norm) + loss_scale_meter.update(scaler.get_scale()) + batch_time.update(time.time() - end) + end = time.time() + + if idx % config.PRINT_FREQ == 0: + lr = optimizer.param_groups[0]['lr'] + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + etas = batch_time.avg * (num_steps - idx) + logger.info( + f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' + f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' + f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' + f'data_time 
{data_time.val:.4f} ({data_time.avg:.4f})\t' + f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' + f'loss_scale {loss_scale_meter.val:.4f}' + + f' ({loss_scale_meter.avg:.4f})\t' + f'mem {memory_used:.0f}MB') + + epoch_time = time.time() - start + logger.info( + f"EPOCH {epoch} training takes " + + f"{datetime.timedelta(seconds=int(epoch_time))}") + + +def main(config): + """ + Starts training process after building the proper model, optimizer, etc. + + Args: + config: config object + """ + + pretrain_data_loader = build_mim_dataloader(config, logger) + + simmim_model = build_model(config, logger) + + simmim_optimizer = build_optimizer(config, + simmim_model, + is_pretrain=True, + logger=logger) + + model, model_wo_ddp = make_ddp(simmim_model) + + n_iter_per_epoch = len(pretrain_data_loader) + + lr_scheduler = build_scheduler(config, simmim_optimizer, n_iter_per_epoch) + + scaler = amp.GradScaler() + + train(config, + pretrain_data_loader, + model, + model_wo_ddp, + simmim_optimizer, + lr_scheduler, + scaler) + + +def build_model(config, logger): + + logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") + + model = build_mim_model(config) + + model.cuda() + + logger.info(str(model)) + + return model + + +def make_ddp(model): + + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[int(os.environ["RANK"])], broadcast_buffers=False) + + model_without_ddp = model.module + + return model, model_without_ddp + + +def setup_rank_worldsize(): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ['WORLD_SIZE']) + print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") + else: + rank = -1 + world_size = -1 + return rank, world_size + + +def setup_distributed_processing(rank, world_size): + torch.cuda.set_device(int(os.environ["RANK"])) + torch.distributed.init_process_group( + backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.barrier() + + +def setup_seeding(config): + seed = config.SEED + dist.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + +if __name__ == '__main__': + _, config = parse_args() + + rank, world_size = setup_rank_worldsize() + + setup_distributed_processing(rank, world_size) + + setup_seeding(config) + + cudnn.benchmark = True + + linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \ + setup_scaled_lr(config) + + config.defrost() + config.TRAIN.BASE_LR = linear_scaled_lr + config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr + config.TRAIN.MIN_LR = linear_scaled_min_lr + config.freeze() + + os.makedirs(config.OUTPUT, exist_ok=True) + logger = create_logger(output_dir=config.OUTPUT, + dist_rank=dist.get_rank(), + name=f"{config.MODEL.NAME}") + + if dist.get_rank() == 0: + path = os.path.join(config.OUTPUT, "config.json") + with open(path, "w") as f: + f.write(config.dump()) + logger.info(f"Full config saved to {path}") + logger.info(config.dump()) + config_file_name = f'{config.TAG}.config.sav' + config_file_path = os.path.join(config.OUTPUT, config_file_name) + joblib.dump(config, config_file_path) + + main(config) diff --git a/pytorch-caney/pytorch_caney/processing.py b/pytorch-caney/pytorch_caney/processing.py new file mode 100755 index 0000000000000000000000000000000000000000..30723d07af06e5b13e54eddde4f93c19ee3fe837 --- /dev/null +++ b/pytorch-caney/pytorch_caney/processing.py @@ -0,0 +1,410 @@ +import logging +import random +from tqdm import 
tqdm + +import numpy as np +from numpy import fliplr, flipud + +import scipy.signal + + +SEED = 42 +np.random.seed(SEED) + +__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" +__email__ = "jordan.a.caraballo-vega@nasa.gov" +__status__ = "Production" + +# ---------------------------------------------------------------------------- +# module processing +# +# General functions to perform standardization of images (numpy arrays). +# A couple of methods have been implemented for testing, including global and +# local standardization for neural networks input. Data manipulation stage, +# extract random patches for training and store them in numpy arrays. +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Module Methods +# --------------------------------------------------------------------------- + + +# --------------------------- Normalization Functions ----------------------- # +def normalize(images, factor=65535.0) -> np.array: + """ + Normalize numpy array in the range of [0,1] + :param images: numpy array in the format (n,w,h,c). + :param factor: float number to normalize images, e.g. 2^(16)-1 + :return: numpy array in the [0,1] range + """ + return images / factor + + +# ------------------------ Standardization Functions ----------------------- # +def global_standardization(images, strategy='per-batch') -> np.array: + """ + Standardize numpy array using global standardization. + :param images: numpy array in the format (n,w,h,c). + :param strategy: can select between per-image or per-batch. + :return: globally standardized numpy array + """ + if strategy == 'per-batch': + mean = np.mean(images) # global mean of all images + std = np.std(images) # global std of all images + for i in range(images.shape[0]): # for each image in images + images[i, :, :, :] = (images[i, :, :, :] - mean) / std + elif strategy == 'per-image': + for i in range(images.shape[0]): # for each image in images + mean = np.mean(images[i, :, :, :]) # image mean + std = np.std(images[i, :, :, :]) # image std + images[i, :, :, :] = (images[i, :, :, :] - mean) / std + return images + + +def local_standardization(images, filename='normalization_data', + ndata=None, strategy='per-batch' + ) -> np.array: + """ + Standardize numpy array using local standardization. + :param images: numpy array in the format (n,w,h,c). + :param filename: filename to store mean and std data. + :param ndata: pandas df with mean and std values for each channel. + :param strategy: can select between per-image or per-batch. 
+ :return: locally standardized numpy array + """ + if ndata: # for inference only + for i in range(images.shape[-1]): # for each channel in images + # standardize all images based on given mean and std + images[:, :, :, i] = \ + (images[:, :, :, i] - ndata['channel_mean'][i]) / \ + ndata['channel_std'][i] + return images + elif strategy == 'per-batch': # for all images in batch + f = open(filename + "_norm_data.csv", "w+") + f.write( + "i,channel_mean,channel_std,channel_mean_post,channel_std_post\n" + ) + for i in range(images.shape[-1]): # for each channel in images + channel_mean = np.mean(images[:, :, :, i]) # mean for each channel + channel_std = np.std(images[:, :, :, i]) # std for each channel + images[:, :, :, i] = \ + (images[:, :, :, i] - channel_mean) / channel_std + channel_mean_post = np.mean(images[:, :, :, i]) + channel_std_post = np.std(images[:, :, :, i]) + # write to file for each channel + f.write('{},{},{},{},{}\n'.format(i, channel_mean, channel_std, + channel_mean_post, + channel_std_post + ) + ) + f.close() # close file + elif strategy == 'per-image': # standardization for each image + for i in range(images.shape[0]): # for each image + for j in range(images.shape[-1]): # for each channel in images + channel_mean = np.mean(images[i, :, :, j]) + channel_std = np.std(images[i, :, :, j]) + images[i, :, :, j] = \ + (images[i, :, :, j] - channel_mean) / channel_std + else: + raise RuntimeError(f'Standardization <{strategy}> not supported') + + return images + + +def standardize_image( + image, + standardization_type: str, + mean: list = None, + std: list = None, + global_min: list = None, + global_max: list = None +): + """ + Standardize an image with the given parameters, simple scaling of values. + Local, Global, and Mixed options. + """ + image = image.astype(np.float32) + if standardization_type == 'local': + for i in range(image.shape[-1]): + image[:, :, i] = (image[:, :, i] - np.mean(image[:, :, i])) / \ + (np.std(image[:, :, i]) + 1e-8) + elif standardization_type == 'minmax': + for i in range(image.shape[-1]): + image[:, :, i] = (image[:, :, i] - 0) / (55-0) + elif standardization_type == 'localminmax': + for i in range(image.shape[-1]): + image[:, :, i] = (image[:, :, i] - np.min(image[:, :, i])) / \ + (np.max(image[:, :, i])-np.min(image[:, :, i])) + elif standardization_type == 'globalminmax': + for i in range(image.shape[-1]): + image[:, :, i] = (image[:, :, i] - global_min) / \ + (global_max - global_min) + elif standardization_type == 'global': + for i in range(image.shape[-1]): + image[:, :, i] = (image[:, :, i] - mean[i]) / (std[i] + 1e-8) + elif standardization_type == 'mixed': + raise NotImplementedError + return image + + +def standardize_batch( + image_batch, + standardization_type: str, + mean: list = None, + std: list = None +): + """ + Standardize a batch of images with the given parameters, simple scaling of values. + Local, Global, and Mixed options. + """ + for item in range(image_batch.shape[0]): + image_batch[item, :, :, :] = standardize_image( + image_batch[item, :, :, :], standardization_type, mean, std) + return image_batch + +# ------------------------ Data Preparation Functions ----------------------- # + + +def get_rand_patches_rand_cond(img, mask, n_patches=16000, sz=160, nclasses=6, + nodata_ascloud=True, method='rand' + ) -> np.array: + """ + Generate training data. + :param img: ndarray in the format (w,h,c). 
+ :param mask: integer ndarray with shape (x_sz, y_sz) + :param n_patches: number of patches + :param sz: tile size, will be used for both height and width + :param nclasses: number of classes present in the output data + :param nodata_ascloud: convert no-data values to cloud labels + :param method: choose between rand, cond, cloud + rand - select N number of random patches for each image + cond - select N number of random patches for each image, + with the condition of having 1+ class per tile. + cloud - select tiles that have clouds + :return: two numpy array with data and labels. + """ + if nodata_ascloud: + # if no-data present, change to final class + mask = mask.values # return numpy array + mask[mask > nclasses] = nclasses # some no-data are 255 or other big + mask[mask < 0] = nclasses # some no-data are -128 or smaller negative + + patches = [] # list to store data patches + labels = [] # list to store label patches + + for i in tqdm(range(n_patches)): + + # Generate random integers from image + xc = random.randint(0, img.shape[0] - sz) + yc = random.randint(0, img.shape[1] - sz) + + if method == 'cond': + # while loop to regenerate random ints if tile has only one class + while len(np.unique(mask[xc:(xc+sz), yc:(yc+sz)])) == 1 or \ + 6 in mask[xc:(xc+sz), yc:(yc+sz)] or \ + img[xc:(xc+sz), yc:(yc+sz), :].values.min() < 0: + xc = random.randint(0, img.shape[0] - sz) + yc = random.randint(0, img.shape[1] - sz) + elif method == 'rand': + while 6 in mask[xc:(xc+sz), yc:(yc+sz)] or \ + img[xc:(xc+sz), yc:(yc+sz), :].values.min() < 0: + xc = random.randint(0, img.shape[0] - sz) + yc = random.randint(0, img.shape[1] - sz) + elif method == 'cloud': + while np.count_nonzero(mask[xc:(xc+sz), yc:(yc+sz)] == 6) < 15: + xc = random.randint(0, img.shape[0] - sz) + yc = random.randint(0, img.shape[1] - sz) + + # Generate img and mask patches + patch_img = img[xc:(xc + sz), yc:(yc + sz)] + patch_mask = mask[xc:(xc + sz), yc:(yc + sz)] + + # Apply some random transformations + random_transformation = np.random.randint(1, 7) + if random_transformation == 1: # flip left and right + patch_img = fliplr(patch_img) + patch_mask = fliplr(patch_mask) + elif random_transformation == 2: # reverse second dimension + patch_img = flipud(patch_img) + patch_mask = flipud(patch_mask) + elif random_transformation == 3: # rotate 90 degrees + patch_img = np.rot90(patch_img, 1) + patch_mask = np.rot90(patch_mask, 1) + elif random_transformation == 4: # rotate 180 degrees + patch_img = np.rot90(patch_img, 2) + patch_mask = np.rot90(patch_mask, 2) + elif random_transformation == 5: # rotate 270 degrees + patch_img = np.rot90(patch_img, 3) + patch_mask = np.rot90(patch_mask, 3) + else: # original image + pass + patches.append(patch_img) + labels.append(patch_mask) + return np.asarray(patches), np.asarray(labels) + + +def get_rand_patches_aug_augcond(img, mask, n_patches=16000, sz=256, + nclasses=6, over=50, nodata_ascloud=True, + nodata=-9999, method='augcond' + ) -> np.array: + """ + Generate training data. + :param images: ndarray in the format (w,h,c). 
+ :param mask: integer ndarray with shape (x_sz, y_sz) + :param n_patches: number of patches + :param sz: tile size, will be used for both height and width + :param nclasses: number of classes present in the output data + :param over: number of pixels to overlap between images + :param nodata_ascloud: convert no-data values to cloud labels + :param method: choose between rand, cond, cloud + aug - select N * 8 number of random patches for each + image after data augmentation. + augcond - select N * 8 number of random patches for + each image, with the condition of having 1+ per + tile, after data augmentation. + :return: two numpy array with data and labels. + """ + mask = mask.values # return numpy array + + if nodata_ascloud: + # if no-data present, change to final class + mask[mask > nclasses] = nodata # some no-data are 255 or other big + mask[mask < 0] = nodata # some no-data are -128 or smaller negative + + patches = [] # list to store data patches + labels = [] # list to store label patches + + for i in tqdm(range(n_patches)): + + # Generate random integers from image + xc = random.randint(0, img.shape[0] - sz - sz) + yc = random.randint(0, img.shape[1] - sz - sz) + + if method == 'augcond': + # while loop to regenerate random ints if tile has only one class + while len(np.unique(mask[xc:(xc + sz), yc:(yc + sz)])) == 1 or \ + nodata in mask[xc:(xc + sz), yc:(yc + sz)] or \ + nodata in mask[(xc + sz - over):(xc + sz + sz - over), + (yc + sz - over):(yc + sz + sz - over)] or \ + nodata in mask[(xc + sz - over):(xc + sz + sz - over), + yc:(yc + sz)]: + xc = random.randint(0, img.shape[0] - sz - sz) + yc = random.randint(0, img.shape[1] - sz - sz) + elif method == 'aug': + # while loop to regenerate random ints if tile has only one class + while nodata in mask[xc:(xc + sz), yc:(yc + sz)] or \ + nodata in mask[(xc + sz - over):(xc + sz + sz - over), + (yc + sz - over):(yc + sz + sz - over)] or \ + nodata in mask[(xc + sz - over):(xc + sz + sz - over), + yc:(yc + sz)]: + xc = random.randint(0, img.shape[0] - sz - sz) + yc = random.randint(0, img.shape[1] - sz - sz) + + # Generate img and mask patches + patch_img = img[xc:(xc + sz), yc:(yc + sz)] # original image patch + patch_mask = mask[xc:(xc + sz), yc:(yc + sz)] # original mask patch + + # Apply transformations for data augmentation + # 1. No augmentation and append to list + patches.append(patch_img) + labels.append(patch_mask) + + # 2. Rotate 90 and append to list + patches.append(np.rot90(patch_img, 1)) + labels.append(np.rot90(patch_mask, 1)) + + # 3. Rotate 180 and append to list + patches.append(np.rot90(patch_img, 2)) + labels.append(np.rot90(patch_mask, 2)) + + # 4. Rotate 270 + patches.append(np.rot90(patch_img, 3)) + labels.append(np.rot90(patch_mask, 3)) + + # 5. Flipped up and down’ + patches.append(flipud(patch_img)) + labels.append(flipud(patch_mask)) + + # 6. Flipped left and right + patches.append(fliplr(patch_img)) + labels.append(fliplr(patch_mask)) + + # 7. overlapping tiles - next tile, down + patches.append(img[(xc + sz - over):(xc + sz + sz - over), + (yc + sz - over):(yc + sz + sz - over)]) + labels.append(mask[(xc + sz - over):(xc + sz + sz - over), + (yc + sz - over):(yc + sz + sz - over)]) + + # 8. 
overlapping tiles - next tile, side + patches.append(img[(xc + sz - over):(xc + sz + sz - over), + yc:(yc + sz)]) + labels.append(mask[(xc + sz - over):(xc + sz + sz - over), + yc:(yc + sz)]) + return np.asarray(patches), np.asarray(labels) + + +# ------------------------ Artifact Removal Functions ----------------------- # + +def _2d_spline(window_size=128, power=2) -> np.array: + """ + Window method for boundaries/edge artifacts smoothing. + :param window_size: size of window/tile to smooth + :param power: spline polynomial power to use + :return: smoothing distribution numpy array + """ + intersection = int(window_size/4) + tria = scipy.signal.triang(window_size) + wind_outer = (abs(2*(tria)) ** power)/2 + wind_outer[intersection:-intersection] = 0 + + wind_inner = 1 - (abs(2*(tria - 1)) ** power)/2 + wind_inner[:intersection] = 0 + wind_inner[-intersection:] = 0 + + wind = wind_inner + wind_outer + wind = wind / np.average(wind) + wind = np.expand_dims(np.expand_dims(wind, 1), 2) + wind = wind * wind.transpose(1, 0, 2) + return wind + + +def _hann_matrix(window_size=128, power=2) -> np.array: + logging.info("Placeholder for next release.") + + +# ------------------------------------------------------------------------------- +# module preprocessing Unit Tests +# ------------------------------------------------------------------------------- +if __name__ == "__main__": + + logging.basicConfig(level=logging.INFO) + + # Unit Test #1 - Testing normalization distributions + x = (np.random.randint(65536, size=(10, 128, 128, 6))).astype('float32') + x_norm = normalize(x, factor=65535) # apply static normalization + assert x_norm.max() == 1.0, "Unexpected max value." + logging.info(f"UT #1 PASS: {x_norm.mean()}, {x_norm.std()}") + + # Unit Test #2 - Testing standardization distributions + standardized = global_standardization(x_norm, strategy='per-batch') + assert standardized.max() > 1.731, "Unexpected max value." + logging.info(f"UT #2 PASS: {standardized.mean()}, {standardized.std()}") + + # Unit Test #3 - Testing standardization distributions + standardized = global_standardization(x_norm, strategy='per-image') + assert standardized.max() > 1.73, "Unexpected max value." + logging.info(f"UT #3 PASS: {standardized.mean()}, {standardized.std()}") + + # Unit Test #4 - Testing standardization distributions + standardized = local_standardization(x_norm, filename='normalization_data', + strategy='per-batch' + ) + assert standardized.max() > 1.74, "Unexpected max value." + logging.info(f"UT #4 PASS: {standardized.mean()}, {standardized.std()}") + + # Unit Test #5 - Testing standardization distributions + standardized = local_standardization(x_norm, filename='normalization_data', + strategy='per-image' + ) + assert standardized.max() > 1.75, "Unexpected max value." 
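    # The _2d_spline window above is intended for blending predictions from
    # overlapping tiles so that tile seams are smoothed out. A hypothetical
    # usage sketch (pred_a / pred_b stand in for per-tile model outputs and
    # are not produced anywhere in this module) would look roughly like:
    #
    #     spline = _2d_spline(window_size=128, power=2)   # shape (128, 128, 1)
    #     canvas = np.zeros((128, 192, 1))                # two tiles, 64 px overlap
    #     weight = np.zeros((128, 192, 1))
    #     canvas[:, :128] += pred_a * spline
    #     weight[:, :128] += spline
    #     canvas[:, 64:] += pred_b * spline
    #     weight[:, 64:] += spline
    #     blended = canvas / np.maximum(weight, 1e-8)     # spline-weighted average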
+ logging.info(f"UT #5 PASS: {standardized.mean()}, {standardized.std()}") diff --git a/pytorch-caney/pytorch_caney/ptc_logging.py b/pytorch-caney/pytorch_caney/ptc_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..3b764620100b9a0aaeae5d9692cceaf6f7852678 --- /dev/null +++ b/pytorch-caney/pytorch_caney/ptc_logging.py @@ -0,0 +1,49 @@ +import os +import sys +import logging +import functools +from termcolor import colored + + +@functools.lru_cache() +def create_logger(output_dir, dist_rank=0, name=''): + # create logger + logger = logging.getLogger(name) + + logger.setLevel(logging.DEBUG) + + logger.propagate = False + + # create formatter + fmt = '[%(asctime)s %(name)s] ' + \ + '(%(filename)s %(lineno)d): ' + \ + '%(levelname)s %(message)s' + + color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ + colored('(%(filename)s %(lineno)d)', 'yellow') + \ + ': %(levelname)s %(message)s' + + # create console handlers for master process + if dist_rank == 0: + + console_handler = logging.StreamHandler(sys.stdout) + + console_handler.setLevel(logging.DEBUG) + + console_handler.setFormatter( + logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) + + logger.addHandler(console_handler) + + # create file handlers + file_handler = logging.FileHandler(os.path.join( + output_dir, f'log_rank{dist_rank}.txt'), mode='a') + + file_handler.setLevel(logging.DEBUG) + + file_handler.setFormatter(logging.Formatter( + fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) + + logger.addHandler(file_handler) + + return logger diff --git a/pytorch-caney/pytorch_caney/tests/config/test_config.yaml b/pytorch-caney/pytorch_caney/tests/config/test_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c13ee63a69fac8965deb6991556ca6e4c4fc4517 --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/config/test_config.yaml @@ -0,0 +1,27 @@ +MODEL: + TYPE: swinv2 + NAME: test_config + DROP_PATH_RATE: 0.1 + SWINV2: + IN_CHANS: 7 + EMBED_DIM: 128 + DEPTHS: [ 2, 2, 18, 2 ] + NUM_HEADS: [ 4, 8, 16, 32 ] + WINDOW_SIZE: 12 +DATA: + IMG_SIZE: 192 + MASK_PATCH_SIZE: 32 + MASK_RATIO: 0.6 +TRAIN: + EPOCHS: 800 + WARMUP_EPOCHS: 10 + BASE_LR: 1e-4 + WARMUP_LR: 5e-7 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] +PRINT_FREQ: 100 +SAVE_FREQ: 5 +TAG: test_config_tag \ No newline at end of file diff --git a/pytorch-caney/pytorch_caney/tests/test_build.py b/pytorch-caney/pytorch_caney/tests/test_build.py new file mode 100644 index 0000000000000000000000000000000000000000..a4728824995dfd851aba53470126caec62572bcd --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/test_build.py @@ -0,0 +1,50 @@ +from pytorch_caney.models.build import build_model +from pytorch_caney.config import get_config + +import unittest +import argparse +import logging + + +class TestBuildModel(unittest.TestCase): + + def setUp(self): + # Initialize any required configuration here + config_path = 'pytorch_caney/' + \ + 'tests/config/test_config.yaml' + args = argparse.Namespace(cfg=config_path) + self.config = get_config(args) + self.logger = logging.getLogger("TestLogger") + self.logger.setLevel(logging.DEBUG) + + def test_build_mim_model(self): + _ = build_model(self.config, + pretrain=True, + pretrain_method='mim', + logger=self.logger) + # Add assertions here to validate the returned 'model' instance + # For example: self.assertIsInstance(model, YourMimModelClass) + + def test_build_swinv2_encoder(self): + _ = build_model(self.config, logger=self.logger) + # Add 
assertions here to validate the returned 'model' instance + # For example: self.assertIsInstance(model, SwinTransformerV2) + + def test_build_unet_decoder(self): + self.config.defrost() + self.config.MODEL.DECODER = 'unet' + self.config.freeze() + _ = build_model(self.config, logger=self.logger) + # Add assertions here to validate the returned 'model' instance + # For example: self.assertIsInstance(model, YourUnetSwinModelClass) + + def test_unknown_decoder_architecture(self): + self.config.defrost() + self.config.MODEL.DECODER = 'unknown_decoder' + self.config.freeze() + with self.assertRaises(NotImplementedError): + build_model(self.config, logger=self.logger) + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch-caney/pytorch_caney/tests/test_config.py b/pytorch-caney/pytorch_caney/tests/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f75c5349c01b8b7ae3165e3c241dcf2957967f50 --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/test_config.py @@ -0,0 +1,44 @@ +from pytorch_caney.config import get_config + +import argparse +import unittest + + +class TestConfig(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.config_yaml_path = 'pytorch_caney/' + \ + 'tests/config/test_config.yaml' + + def test_default_config(self): + # Get the default configuration + args = argparse.Namespace(cfg=self.config_yaml_path) + config = get_config(args) + + # Test specific configuration values + self.assertEqual(config.DATA.BATCH_SIZE, 128) + self.assertEqual(config.DATA.DATASET, 'MODIS') + self.assertEqual(config.MODEL.TYPE, 'swinv2') + self.assertEqual(config.MODEL.NAME, 'test_config') + self.assertEqual(config.TRAIN.EPOCHS, 800) + + def test_custom_config(self): + # Test with custom arguments + args = argparse.Namespace( + cfg=self.config_yaml_path, + batch_size=64, + dataset='CustomDataset', + data_paths=['solongandthanksforallthefish'], + ) + config = get_config(args) + + # Test specific configuration values with custom arguments + self.assertEqual(config.DATA.BATCH_SIZE, 64) + self.assertEqual(config.DATA.DATASET, 'CustomDataset') + self.assertEqual(config.DATA.DATA_PATHS, + ['solongandthanksforallthefish']) + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch-caney/pytorch_caney/tests/test_data.py b/pytorch-caney/pytorch_caney/tests/test_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a5852d936a8841643eddc6e5000769f82f7357 --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/test_data.py @@ -0,0 +1,38 @@ +from pytorch_caney.data.datamodules.finetune_datamodule \ + import get_dataset_from_dict + +from pytorch_caney.data.datamodules.finetune_datamodule \ + import DATASETS + +import unittest + + +class TestGetDatasetFromDict(unittest.TestCase): + + def test_existing_datasets(self): + # Test existing datasets + for dataset_name in ['modis', 'modislc9', 'modislc5']: + dataset = get_dataset_from_dict(dataset_name) + self.assertIsNotNone(dataset) + + def test_non_existing_dataset(self): + # Test non-existing dataset + invalid_dataset_name = 'invalid_dataset' + with self.assertRaises(KeyError) as context: + get_dataset_from_dict(invalid_dataset_name) + expected_error_msg = f'"{invalid_dataset_name} ' + \ + 'is not an existing dataset. 
Available datasets:' + \ + f' {DATASETS.keys()}"' + self.assertEqual(str(context.exception), expected_error_msg) + + def test_dataset_name_case_insensitive(self): + # Test case insensitivity + dataset_name = 'MoDiSLC5' + dataset = get_dataset_from_dict(dataset_name) + self.assertIsNotNone(dataset) + +# Add more test cases as needed + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch-caney/pytorch_caney/tests/test_loss_utils.py b/pytorch-caney/pytorch_caney/tests/test_loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..74a256a34e179ea093c4e4a950f5f093fab3663a --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/test_loss_utils.py @@ -0,0 +1,46 @@ +from pytorch_caney.loss.utils import to_tensor + +import unittest +import numpy as np +import torch + + +class TestToTensorFunction(unittest.TestCase): + + def test_tensor_input(self): + tensor = torch.tensor([1, 2, 3]) + result = to_tensor(tensor) + self.assertTrue(torch.equal(result, tensor)) + + def test_tensor_input_with_dtype(self): + tensor = torch.tensor([1, 2, 3]) + result = to_tensor(tensor, dtype=torch.float32) + self.assertTrue(torch.equal(result, tensor.float())) + + def test_numpy_array_input(self): + numpy_array = np.array([1, 2, 3]) + expected_tensor = torch.tensor([1, 2, 3]) + result = to_tensor(numpy_array) + self.assertTrue(torch.equal(result, expected_tensor)) + + def test_numpy_array_input_with_dtype(self): + numpy_array = np.array([1, 2, 3]) + expected_tensor = torch.tensor([1, 2, 3], dtype=torch.float32) + result = to_tensor(numpy_array, dtype=torch.float32) + self.assertTrue(torch.equal(result, expected_tensor)) + + def test_list_input(self): + input_list = [1, 2, 3] + expected_tensor = torch.tensor([1, 2, 3]) + result = to_tensor(input_list) + self.assertTrue(torch.equal(result, expected_tensor)) + + def test_list_input_with_dtype(self): + input_list = [1, 2, 3] + expected_tensor = torch.tensor([1, 2, 3], dtype=torch.float32) + result = to_tensor(input_list, dtype=torch.float32) + self.assertTrue(torch.equal(result, expected_tensor)) + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch-caney/pytorch_caney/tests/test_lr_scheduler.py b/pytorch-caney/pytorch_caney/tests/test_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cd7f2d2c92a8c60ee9204d396d9741abcc7ad7 --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/test_lr_scheduler.py @@ -0,0 +1,48 @@ +from pytorch_caney.lr_scheduler import build_scheduler + +import unittest +from unittest.mock import Mock, patch + + +class TestBuildScheduler(unittest.TestCase): + def setUp(self): + self.config = Mock( + TRAIN=Mock( + EPOCHS=300, + WARMUP_EPOCHS=20, + MIN_LR=1e-6, + WARMUP_LR=1e-7, + LR_SCHEDULER=Mock( + NAME='cosine', + DECAY_EPOCHS=30, + DECAY_RATE=0.1, + MULTISTEPS=[50, 100], + GAMMA=0.1 + ) + ) + ) + + self.optimizer = Mock() + self.n_iter_per_epoch = 100 # Example value + + def test_build_cosine_scheduler(self): + with patch('pytorch_caney.lr_scheduler.CosineLRScheduler') \ + as mock_cosine_scheduler: + _ = build_scheduler(self.config, + self.optimizer, + self.n_iter_per_epoch) + + mock_cosine_scheduler.assert_called_once_with( + self.optimizer, + t_initial=300 * 100, + cycle_mul=1., + lr_min=1e-6, + warmup_lr_init=1e-7, + warmup_t=20 * 100, + cycle_limit=1, + t_in_epochs=False + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch-caney/pytorch_caney/tests/test_transforms.py b/pytorch-caney/pytorch_caney/tests/test_transforms.py new file mode 
100644 index 0000000000000000000000000000000000000000..9656e0b37947bcc3b20023d54f821803591d6f68 --- /dev/null +++ b/pytorch-caney/pytorch_caney/tests/test_transforms.py @@ -0,0 +1,70 @@ +from pytorch_caney.config import get_config +from pytorch_caney.data.transforms import SimmimTransform +from pytorch_caney.data.transforms import TensorResizeTransform + +import argparse +import unittest +import torch +import numpy as np + + +class TestTransforms(unittest.TestCase): + + def setUp(self): + # Initialize any required configuration here + config_path = 'pytorch_caney/' + \ + 'tests/config/test_config.yaml' + args = argparse.Namespace(cfg=config_path) + self.config = get_config(args) + + def test_simmim_transform(self): + + # Create an instance of SimmimTransform + transform = SimmimTransform(self.config) + + # Create a sample ndarray + img = np.random.randn(self.config.DATA.IMG_SIZE, + self.config.DATA.IMG_SIZE, + 7) + + # Apply the transform + img_transformed, mask = transform(img) + + # Assertions + self.assertIsInstance(img_transformed, torch.Tensor) + self.assertEqual(img_transformed.shape, (7, + self.config.DATA.IMG_SIZE, + self.config.DATA.IMG_SIZE)) + self.assertIsInstance(mask, np.ndarray) + + def test_tensor_resize_transform(self): + # Create an instance of TensorResizeTransform + transform = TensorResizeTransform(self.config) + + # Create a sample image tensor + img = np.random.randn(self.config.DATA.IMG_SIZE, + self.config.DATA.IMG_SIZE, + 7) + + target = np.random.randint(0, 5, + size=((self.config.DATA.IMG_SIZE, + self.config.DATA.IMG_SIZE))) + + # Apply the transform + img_transformed = transform(img) + target_transformed = transform(target) + + # Assertions + self.assertIsInstance(img_transformed, torch.Tensor) + self.assertEqual(img_transformed.shape, + (7, self.config.DATA.IMG_SIZE, + self.config.DATA.IMG_SIZE)) + + self.assertIsInstance(target_transformed, torch.Tensor) + self.assertEqual(target_transformed.shape, + (1, self.config.DATA.IMG_SIZE, + self.config.DATA.IMG_SIZE)) + + +if __name__ == '__main__': + unittest.main() diff --git a/pytorch-caney/pytorch_caney/training/fine_tuning.py b/pytorch-caney/pytorch_caney/training/fine_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/training/mim_utils.py b/pytorch-caney/pytorch_caney/training/mim_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..949f307d95a9ce041e3f9a860a0ab2d09ff5640e --- /dev/null +++ b/pytorch-caney/pytorch_caney/training/mim_utils.py @@ -0,0 +1,706 @@ +from functools import partial +from torch import optim as optim + +import os +import torch +import torch.distributed as dist +import numpy as np +from scipy import interpolate + + +def build_optimizer(config, model, is_pretrain=False, logger=None): + """ + Build optimizer, set weight decay of normalization to 0 by default. + AdamW only. 
+ """ + logger.info('>>>>>>>>>> Build Optimizer') + + skip = {} + + skip_keywords = {} + + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + + if hasattr(model, 'no_weight_decay_keywords'): + skip_keywords = model.no_weight_decay_keywords() + + if is_pretrain: + parameters = get_pretrain_param_groups(model, skip, skip_keywords) + + else: + + depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \ + else config.MODEL.SWINV2.DEPTHS + + num_layers = sum(depths) + + get_layer_func = partial(get_swin_layer, + num_layers=num_layers + 2, + depths=depths) + + scales = list(config.TRAIN.LAYER_DECAY ** i for i in + reversed(range(num_layers + 2))) + + parameters = get_finetune_param_groups(model, + config.TRAIN.BASE_LR, + config.TRAIN.WEIGHT_DECAY, + get_layer_func, + scales, + skip, + skip_keywords) + + optimizer = None + + optimizer = optim.AdamW(parameters, + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY) + + logger.info(optimizer) + + return optimizer + + +def set_weight_decay(model, skip_list=(), skip_keywords=()): + """ + + Args: + model (_type_): _description_ + skip_list (tuple, optional): _description_. Defaults to (). + skip_keywords (tuple, optional): _description_. Defaults to (). + + Returns: + _type_: _description_ + """ + + has_decay = [] + + no_decay = [] + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue # frozen weights + + if len(param.shape) == 1 or name.endswith(".bias") \ + or (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + no_decay.append(param) + + else: + + has_decay.append(param) + + return [{'params': has_decay}, + {'params': no_decay, 'weight_decay': 0.}] + + +def check_keywords_in_name(name, keywords=()): + + isin = False + + for keyword in keywords: + + if keyword in name: + + isin = True + + return isin + + +def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): + + has_decay = [] + + no_decay = [] + + has_decay_name = [] + + no_decay_name = [] + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue + + if len(param.shape) == 1 or name.endswith(".bias") or \ + (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + no_decay.append(param) + + no_decay_name.append(name) + + else: + + has_decay.append(param) + + has_decay_name.append(name) + + return [{'params': has_decay}, + {'params': no_decay, 'weight_decay': 0.}] + + +def get_swin_layer(name, num_layers, depths): + + if name in ("mask_token"): + + return 0 + + elif name.startswith("patch_embed"): + + return 0 + + elif name.startswith("layers"): + + layer_id = int(name.split('.')[1]) + + block_id = name.split('.')[3] + + if block_id == 'reduction' or block_id == 'norm': + + return sum(depths[:layer_id + 1]) + + layer_id = sum(depths[:layer_id]) + int(block_id) + + return layer_id + 1 + + else: + + return num_layers - 1 + + +def get_finetune_param_groups(model, + lr, + weight_decay, + get_layer_func, + scales, + skip_list=(), + skip_keywords=()): + + parameter_group_names = {} + + parameter_group_vars = {} + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue + + if len(param.shape) == 1 or name.endswith(".bias") \ + or (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + group_name = "no_decay" + + this_weight_decay = 0. 
+ + else: + + group_name = "decay" + + this_weight_decay = weight_decay + + if get_layer_func is not None: + + layer_id = get_layer_func(name) + + group_name = "layer_%d_%s" % (layer_id, group_name) + + else: + + layer_id = None + + if group_name not in parameter_group_names: + + if scales is not None: + + scale = scales[layer_id] + + else: + + scale = 1. + + parameter_group_names[group_name] = { + "group_name": group_name, + "weight_decay": this_weight_decay, + "params": [], + "lr": lr * scale, + "lr_scale": scale, + } + + parameter_group_vars[group_name] = { + "group_name": group_name, + "weight_decay": this_weight_decay, + "params": [], + "lr": lr * scale, + "lr_scale": scale + } + + parameter_group_vars[group_name]["params"].append(param) + + parameter_group_names[group_name]["params"].append(name) + + return list(parameter_group_vars.values()) + + +def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): + + logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") + + if config.MODEL.RESUME.startswith('https'): + + checkpoint = torch.hub.load_state_dict_from_url( + config.MODEL.RESUME, map_location='cpu', check_hash=True) + + else: + + checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') + + # re-map keys due to name change (only for loading provided models) + rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] + + for k in rpe_mlp_keys: + + checkpoint['model'][k.replace( + 'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) + + msg = model.load_state_dict(checkpoint['model'], strict=False) + + logger.info(msg) + + max_accuracy = 0.0 + + if not config.EVAL_MODE and 'optimizer' in checkpoint \ + and 'lr_scheduler' in checkpoint \ + and 'scaler' in checkpoint \ + and 'epoch' in checkpoint: + + optimizer.load_state_dict(checkpoint['optimizer']) + + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + scaler.load_state_dict(checkpoint['scaler']) + + config.defrost() + config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 + config.freeze() + + logger.info( + f"=> loaded successfully '{config.MODEL.RESUME}' " + + f"(epoch {checkpoint['epoch']})") + + if 'max_accuracy' in checkpoint: + max_accuracy = checkpoint['max_accuracy'] + + else: + max_accuracy = 0.0 + + del checkpoint + + torch.cuda.empty_cache() + + return max_accuracy + + +def save_checkpoint(config, epoch, model, max_accuracy, + optimizer, lr_scheduler, scaler, logger): + + save_state = {'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'scaler': scaler.state_dict(), + 'max_accuracy': max_accuracy, + 'epoch': epoch, + 'config': config} + + save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') + + logger.info(f"{save_path} saving......") + + torch.save(save_state, save_path) + + logger.info(f"{save_path} saved !!!") + + +def get_grad_norm(parameters, norm_type=2): + + if isinstance(parameters, torch.Tensor): + + parameters = [parameters] + + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + norm_type = float(norm_type) + + total_norm = 0 + + for p in parameters: + + param_norm = p.grad.data.norm(norm_type) + + total_norm += param_norm.item() ** norm_type + + total_norm = total_norm ** (1. 
/ norm_type) + + return total_norm + + +def auto_resume_helper(output_dir, logger): + + checkpoints = os.listdir(output_dir) + + checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] + + logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") + + if len(checkpoints) > 0: + + latest_checkpoint = max([os.path.join(output_dir, d) + for d in checkpoints], key=os.path.getmtime) + + logger.info(f"The latest checkpoint founded: {latest_checkpoint}") + + resume_file = latest_checkpoint + + else: + + resume_file = None + + return resume_file + + +def reduce_tensor(tensor): + + rt = tensor.clone() + + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + + rt /= dist.get_world_size() + + return rt + + +def load_pretrained(config, model, logger): + + logger.info( + f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........") + + checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') + + checkpoint_model = checkpoint['model'] + + if any([True if 'encoder.' in k else + False for k in checkpoint_model.keys()]): + + checkpoint_model = {k.replace( + 'encoder.', ''): v for k, v in checkpoint_model.items() + if k.startswith('encoder.')} + + logger.info('Detect pre-trained model, remove [encoder.] prefix.') + + else: + + logger.info( + 'Detect non-pre-trained model, pass without doing anything.') + + if config.MODEL.TYPE in ['swin', 'swinv2']: + + logger.info( + ">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") + + checkpoint = remap_pretrained_keys_swin( + model, checkpoint_model, logger) + + else: + + raise NotImplementedError + + msg = model.load_state_dict(checkpoint_model, strict=False) + + logger.info(msg) + + del checkpoint + + torch.cuda.empty_cache() + + logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") + + +def remap_pretrained_keys_swin(model, checkpoint_model, logger): + + state_dict = model.state_dict() + + # Geometric interpolation when pre-trained patch size mismatch + # with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + + for key in all_keys: + + if "relative_position_bias_table" in key: + + logger.info(f"Key: {key}") + + rel_position_bias_table_pretrained = checkpoint_model[key] + + rel_position_bias_table_current = state_dict[key] + + L1, nH1 = rel_position_bias_table_pretrained.size() + + L2, nH2 = rel_position_bias_table_current.size() + + if nH1 != nH2: + logger.info(f"Error in loading {key}, passing......") + + else: + + if L1 != L2: + + logger.info( + f"{key}: Interpolate " + + "relative_position_bias_table using geo.") + + src_size = int(L1 ** 0.5) + + dst_size = int(L2 ** 0.5) + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + + while right - left > 1e-6: + + q = (left + right) / 2.0 + + gp = geometric_progression(1, q, src_size // 2) + + if gp > dst_size // 2: + + right = q + + else: + + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + + cur = 1 + + for i in range(src_size // 2): + + dis.append(cur) + + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + + y = r_ids + [0] + dis + + t = dst_size // 2.0 + + dx = np.arange(-t, t + 0.1, 1.0) + + dy = np.arange(-t, t + 0.1, 1.0) + + logger.info("Original positions = %s" % str(x)) + + logger.info("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(nH1): + + z = rel_position_bias_table_pretrained[:, i].view( + src_size, src_size).float().numpy() + + f_cubic = interpolate.interp2d(x, y, z, kind='cubic') + + 
all_rel_pos_bias_host = \ + torch.Tensor(f_cubic(dx, dy) + ).contiguous().view(-1, 1) + + all_rel_pos_bias.append( + all_rel_pos_bias_host.to( + rel_position_bias_table_pretrained.device)) + + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + checkpoint_model[key] = new_rel_pos_bias + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in checkpoint_model.keys() if "relative_position_index" in k] + + for k in relative_position_index_keys: + + del checkpoint_model[k] + + # delete relative_coords_table since we always re-init it + relative_coords_table_keys = [ + k for k in checkpoint_model.keys() if "relative_coords_table" in k] + + for k in relative_coords_table_keys: + + del checkpoint_model[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] + + for k in attn_mask_keys: + + del checkpoint_model[k] + + return checkpoint_model + + +def remap_pretrained_keys_vit(model, checkpoint_model, logger): + + # Duplicate shared rel_pos_bias to each layer + if getattr(model, 'use_rel_pos_bias', False) and \ + "rel_pos_bias.relative_position_bias_table" in checkpoint_model: + + logger.info( + "Expand the shared relative position " + + "embedding to each transformer block.") + + num_layers = model.get_num_layers() + + rel_pos_bias = \ + checkpoint_model["rel_pos_bias.relative_position_bias_table"] + + for i in range(num_layers): + + checkpoint_model["blocks.%d.attn.relative_position_bias_table" % + i] = rel_pos_bias.clone() + + checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") + + # Geometric interpolation when pre-trained patch + # size mismatch with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + + for key in all_keys: + + if "relative_position_index" in key: + + checkpoint_model.pop(key) + + if "relative_position_bias_table" in key: + + rel_pos_bias = checkpoint_model[key] + + src_num_pos, num_attn_heads = rel_pos_bias.size() + + dst_num_pos, _ = model.state_dict()[key].size() + + dst_patch_shape = model.patch_embed.patch_shape + + if dst_patch_shape[0] != dst_patch_shape[1]: + + raise NotImplementedError() + + num_extra_tokens = dst_num_pos - \ + (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + + if src_size != dst_size: + + logger.info("Position interpolate for " + + "%s from %dx%d to %dx%d" % ( + key, + src_size, + src_size, + dst_size, + dst_size)) + + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + + while right - left > 1e-6: + + q = (left + right) / 2.0 + + gp = geometric_progression(1, q, src_size // 2) + + if gp > dst_size // 2: + + right = q + + else: + + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + + cur = 1 + + for i in range(src_size // 2): + + dis.append(cur) + + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + + y = r_ids + [0] + dis + + t = dst_size // 2.0 + + dx = np.arange(-t, t + 0.1, 1.0) + + dy = np.arange(-t, t + 0.1, 1.0) + + logger.info("Original positions = %s" % str(x)) + + logger.info("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + + z = rel_pos_bias[:, i].view( + src_size, 
src_size).float().numpy() + + f = interpolate.interp2d(x, y, z, kind='cubic') + + all_rel_pos_bias_host = \ + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1) + + all_rel_pos_bias.append( + all_rel_pos_bias_host.to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat( + (rel_pos_bias, extra_tokens), dim=0) + + checkpoint_model[key] = new_rel_pos_bias + + return checkpoint_model diff --git a/pytorch-caney/pytorch_caney/training/pre_training.py b/pytorch-caney/pytorch_caney/training/pre_training.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/training/simmim_utils.py b/pytorch-caney/pytorch_caney/training/simmim_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..949f307d95a9ce041e3f9a860a0ab2d09ff5640e --- /dev/null +++ b/pytorch-caney/pytorch_caney/training/simmim_utils.py @@ -0,0 +1,706 @@ +from functools import partial +from torch import optim as optim + +import os +import torch +import torch.distributed as dist +import numpy as np +from scipy import interpolate + + +def build_optimizer(config, model, is_pretrain=False, logger=None): + """ + Build optimizer, set weight decay of normalization to 0 by default. + AdamW only. + """ + logger.info('>>>>>>>>>> Build Optimizer') + + skip = {} + + skip_keywords = {} + + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + + if hasattr(model, 'no_weight_decay_keywords'): + skip_keywords = model.no_weight_decay_keywords() + + if is_pretrain: + parameters = get_pretrain_param_groups(model, skip, skip_keywords) + + else: + + depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \ + else config.MODEL.SWINV2.DEPTHS + + num_layers = sum(depths) + + get_layer_func = partial(get_swin_layer, + num_layers=num_layers + 2, + depths=depths) + + scales = list(config.TRAIN.LAYER_DECAY ** i for i in + reversed(range(num_layers + 2))) + + parameters = get_finetune_param_groups(model, + config.TRAIN.BASE_LR, + config.TRAIN.WEIGHT_DECAY, + get_layer_func, + scales, + skip, + skip_keywords) + + optimizer = None + + optimizer = optim.AdamW(parameters, + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY) + + logger.info(optimizer) + + return optimizer + + +def set_weight_decay(model, skip_list=(), skip_keywords=()): + """ + + Args: + model (_type_): _description_ + skip_list (tuple, optional): _description_. Defaults to (). + skip_keywords (tuple, optional): _description_. Defaults to (). 
+ + Returns: + _type_: _description_ + """ + + has_decay = [] + + no_decay = [] + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue # frozen weights + + if len(param.shape) == 1 or name.endswith(".bias") \ + or (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + no_decay.append(param) + + else: + + has_decay.append(param) + + return [{'params': has_decay}, + {'params': no_decay, 'weight_decay': 0.}] + + +def check_keywords_in_name(name, keywords=()): + + isin = False + + for keyword in keywords: + + if keyword in name: + + isin = True + + return isin + + +def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): + + has_decay = [] + + no_decay = [] + + has_decay_name = [] + + no_decay_name = [] + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue + + if len(param.shape) == 1 or name.endswith(".bias") or \ + (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + no_decay.append(param) + + no_decay_name.append(name) + + else: + + has_decay.append(param) + + has_decay_name.append(name) + + return [{'params': has_decay}, + {'params': no_decay, 'weight_decay': 0.}] + + +def get_swin_layer(name, num_layers, depths): + + if name in ("mask_token"): + + return 0 + + elif name.startswith("patch_embed"): + + return 0 + + elif name.startswith("layers"): + + layer_id = int(name.split('.')[1]) + + block_id = name.split('.')[3] + + if block_id == 'reduction' or block_id == 'norm': + + return sum(depths[:layer_id + 1]) + + layer_id = sum(depths[:layer_id]) + int(block_id) + + return layer_id + 1 + + else: + + return num_layers - 1 + + +def get_finetune_param_groups(model, + lr, + weight_decay, + get_layer_func, + scales, + skip_list=(), + skip_keywords=()): + + parameter_group_names = {} + + parameter_group_vars = {} + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue + + if len(param.shape) == 1 or name.endswith(".bias") \ + or (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + group_name = "no_decay" + + this_weight_decay = 0. + + else: + + group_name = "decay" + + this_weight_decay = weight_decay + + if get_layer_func is not None: + + layer_id = get_layer_func(name) + + group_name = "layer_%d_%s" % (layer_id, group_name) + + else: + + layer_id = None + + if group_name not in parameter_group_names: + + if scales is not None: + + scale = scales[layer_id] + + else: + + scale = 1. 
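            # `scale` implements layer-wise learning-rate decay: scales was
            # built as LAYER_DECAY ** i for i in reversed(range(num_layers + 2)),
            # so shallow groups (patch_embed, mask_token, layer_id 0) get the
            # smallest multiplier while parameters mapped to the last layer id
            # keep scale 1.0. As an example, with LAYER_DECAY = 0.9 and
            # num_layers + 2 = 6, the multipliers are 0.9**5 (about 0.59),
            # 0.9**4, ..., 0.9, 1.0; each group below is created with
            # lr = BASE_LR * scale.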
+ + parameter_group_names[group_name] = { + "group_name": group_name, + "weight_decay": this_weight_decay, + "params": [], + "lr": lr * scale, + "lr_scale": scale, + } + + parameter_group_vars[group_name] = { + "group_name": group_name, + "weight_decay": this_weight_decay, + "params": [], + "lr": lr * scale, + "lr_scale": scale + } + + parameter_group_vars[group_name]["params"].append(param) + + parameter_group_names[group_name]["params"].append(name) + + return list(parameter_group_vars.values()) + + +def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): + + logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") + + if config.MODEL.RESUME.startswith('https'): + + checkpoint = torch.hub.load_state_dict_from_url( + config.MODEL.RESUME, map_location='cpu', check_hash=True) + + else: + + checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') + + # re-map keys due to name change (only for loading provided models) + rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] + + for k in rpe_mlp_keys: + + checkpoint['model'][k.replace( + 'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) + + msg = model.load_state_dict(checkpoint['model'], strict=False) + + logger.info(msg) + + max_accuracy = 0.0 + + if not config.EVAL_MODE and 'optimizer' in checkpoint \ + and 'lr_scheduler' in checkpoint \ + and 'scaler' in checkpoint \ + and 'epoch' in checkpoint: + + optimizer.load_state_dict(checkpoint['optimizer']) + + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + scaler.load_state_dict(checkpoint['scaler']) + + config.defrost() + config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 + config.freeze() + + logger.info( + f"=> loaded successfully '{config.MODEL.RESUME}' " + + f"(epoch {checkpoint['epoch']})") + + if 'max_accuracy' in checkpoint: + max_accuracy = checkpoint['max_accuracy'] + + else: + max_accuracy = 0.0 + + del checkpoint + + torch.cuda.empty_cache() + + return max_accuracy + + +def save_checkpoint(config, epoch, model, max_accuracy, + optimizer, lr_scheduler, scaler, logger): + + save_state = {'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'scaler': scaler.state_dict(), + 'max_accuracy': max_accuracy, + 'epoch': epoch, + 'config': config} + + save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') + + logger.info(f"{save_path} saving......") + + torch.save(save_state, save_path) + + logger.info(f"{save_path} saved !!!") + + +def get_grad_norm(parameters, norm_type=2): + + if isinstance(parameters, torch.Tensor): + + parameters = [parameters] + + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + norm_type = float(norm_type) + + total_norm = 0 + + for p in parameters: + + param_norm = p.grad.data.norm(norm_type) + + total_norm += param_norm.item() ** norm_type + + total_norm = total_norm ** (1. 
/ norm_type) + + return total_norm + + +def auto_resume_helper(output_dir, logger): + + checkpoints = os.listdir(output_dir) + + checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] + + logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") + + if len(checkpoints) > 0: + + latest_checkpoint = max([os.path.join(output_dir, d) + for d in checkpoints], key=os.path.getmtime) + + logger.info(f"The latest checkpoint founded: {latest_checkpoint}") + + resume_file = latest_checkpoint + + else: + + resume_file = None + + return resume_file + + +def reduce_tensor(tensor): + + rt = tensor.clone() + + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + + rt /= dist.get_world_size() + + return rt + + +def load_pretrained(config, model, logger): + + logger.info( + f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........") + + checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') + + checkpoint_model = checkpoint['model'] + + if any([True if 'encoder.' in k else + False for k in checkpoint_model.keys()]): + + checkpoint_model = {k.replace( + 'encoder.', ''): v for k, v in checkpoint_model.items() + if k.startswith('encoder.')} + + logger.info('Detect pre-trained model, remove [encoder.] prefix.') + + else: + + logger.info( + 'Detect non-pre-trained model, pass without doing anything.') + + if config.MODEL.TYPE in ['swin', 'swinv2']: + + logger.info( + ">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") + + checkpoint = remap_pretrained_keys_swin( + model, checkpoint_model, logger) + + else: + + raise NotImplementedError + + msg = model.load_state_dict(checkpoint_model, strict=False) + + logger.info(msg) + + del checkpoint + + torch.cuda.empty_cache() + + logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") + + +def remap_pretrained_keys_swin(model, checkpoint_model, logger): + + state_dict = model.state_dict() + + # Geometric interpolation when pre-trained patch size mismatch + # with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + + for key in all_keys: + + if "relative_position_bias_table" in key: + + logger.info(f"Key: {key}") + + rel_position_bias_table_pretrained = checkpoint_model[key] + + rel_position_bias_table_current = state_dict[key] + + L1, nH1 = rel_position_bias_table_pretrained.size() + + L2, nH2 = rel_position_bias_table_current.size() + + if nH1 != nH2: + logger.info(f"Error in loading {key}, passing......") + + else: + + if L1 != L2: + + logger.info( + f"{key}: Interpolate " + + "relative_position_bias_table using geo.") + + src_size = int(L1 ** 0.5) + + dst_size = int(L2 ** 0.5) + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + + while right - left > 1e-6: + + q = (left + right) / 2.0 + + gp = geometric_progression(1, q, src_size // 2) + + if gp > dst_size // 2: + + right = q + + else: + + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + + cur = 1 + + for i in range(src_size // 2): + + dis.append(cur) + + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + + y = r_ids + [0] + dis + + t = dst_size // 2.0 + + dx = np.arange(-t, t + 0.1, 1.0) + + dy = np.arange(-t, t + 0.1, 1.0) + + logger.info("Original positions = %s" % str(x)) + + logger.info("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(nH1): + + z = rel_position_bias_table_pretrained[:, i].view( + src_size, src_size).float().numpy() + + f_cubic = interpolate.interp2d(x, y, z, kind='cubic') + + 
all_rel_pos_bias_host = \ + torch.Tensor(f_cubic(dx, dy) + ).contiguous().view(-1, 1) + + all_rel_pos_bias.append( + all_rel_pos_bias_host.to( + rel_position_bias_table_pretrained.device)) + + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + checkpoint_model[key] = new_rel_pos_bias + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in checkpoint_model.keys() if "relative_position_index" in k] + + for k in relative_position_index_keys: + + del checkpoint_model[k] + + # delete relative_coords_table since we always re-init it + relative_coords_table_keys = [ + k for k in checkpoint_model.keys() if "relative_coords_table" in k] + + for k in relative_coords_table_keys: + + del checkpoint_model[k] + + # delete attn_mask since we always re-init it + attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] + + for k in attn_mask_keys: + + del checkpoint_model[k] + + return checkpoint_model + + +def remap_pretrained_keys_vit(model, checkpoint_model, logger): + + # Duplicate shared rel_pos_bias to each layer + if getattr(model, 'use_rel_pos_bias', False) and \ + "rel_pos_bias.relative_position_bias_table" in checkpoint_model: + + logger.info( + "Expand the shared relative position " + + "embedding to each transformer block.") + + num_layers = model.get_num_layers() + + rel_pos_bias = \ + checkpoint_model["rel_pos_bias.relative_position_bias_table"] + + for i in range(num_layers): + + checkpoint_model["blocks.%d.attn.relative_position_bias_table" % + i] = rel_pos_bias.clone() + + checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") + + # Geometric interpolation when pre-trained patch + # size mismatch with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + + for key in all_keys: + + if "relative_position_index" in key: + + checkpoint_model.pop(key) + + if "relative_position_bias_table" in key: + + rel_pos_bias = checkpoint_model[key] + + src_num_pos, num_attn_heads = rel_pos_bias.size() + + dst_num_pos, _ = model.state_dict()[key].size() + + dst_patch_shape = model.patch_embed.patch_shape + + if dst_patch_shape[0] != dst_patch_shape[1]: + + raise NotImplementedError() + + num_extra_tokens = dst_num_pos - \ + (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) + + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + + dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) + + if src_size != dst_size: + + logger.info("Position interpolate for " + + "%s from %dx%d to %dx%d" % ( + key, + src_size, + src_size, + dst_size, + dst_size)) + + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + + return a * (1.0 - r ** n) / (1.0 - r) + + left, right = 1.01, 1.5 + + while right - left > 1e-6: + + q = (left + right) / 2.0 + + gp = geometric_progression(1, q, src_size // 2) + + if gp > dst_size // 2: + + right = q + + else: + + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + + cur = 1 + + for i in range(src_size // 2): + + dis.append(cur) + + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + + y = r_ids + [0] + dis + + t = dst_size // 2.0 + + dx = np.arange(-t, t + 0.1, 1.0) + + dy = np.arange(-t, t + 0.1, 1.0) + + logger.info("Original positions = %s" % str(x)) + + logger.info("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + + z = rel_pos_bias[:, i].view( + src_size, 
src_size).float().numpy() + + f = interpolate.interp2d(x, y, z, kind='cubic') + + all_rel_pos_bias_host = \ + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1) + + all_rel_pos_bias.append( + all_rel_pos_bias_host.to(rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + + new_rel_pos_bias = torch.cat( + (rel_pos_bias, extra_tokens), dim=0) + + checkpoint_model[key] = new_rel_pos_bias + + return checkpoint_model diff --git a/pytorch-caney/pytorch_caney/training/utils.py b/pytorch-caney/pytorch_caney/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pytorch-caney/pytorch_caney/utils.py b/pytorch-caney/pytorch_caney/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af65b1124dc00219c24005f32026f26d274ab795 --- /dev/null +++ b/pytorch-caney/pytorch_caney/utils.py @@ -0,0 +1,15 @@ +import torch +import warnings + + +def check_gpus_available(ngpus: int) -> None: + ngpus_available = torch.cuda.device_count() + if ngpus < ngpus_available: + msg = 'Not using all available GPUS.' + \ + f' N GPUs available: {ngpus_available},' + \ + f' N GPUs selected: {ngpus}. ' + warnings.warn(msg) + elif ngpus > ngpus_available: + msg = 'Not enough GPUs to satisfy selected amount' + \ + f': {ngpus}. N GPUs available: {ngpus_available}' + warnings.warn(msg) diff --git a/pytorch-caney/requirements/Dockerfile b/pytorch-caney/requirements/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..5f6eb68cec066e614333abe68a3f5edc4227fea2 --- /dev/null +++ b/pytorch-caney/requirements/Dockerfile @@ -0,0 +1,117 @@ +# Arguments to pass to the image +ARG VERSION_DATE=23.01 +ARG FROM_IMAGE=nvcr.io/nvidia/pytorch + +# Import RAPIDS container as the BASE Image (cuda base image) +FROM ${FROM_IMAGE}:${VERSION_DATE}-py3 + +# Ubuntu needs noninteractive to be forced +ENV DEBIAN_FRONTEND noninteractive +ENV PROJ_LIB="/usr/share/proj" +ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" +ENV C_INCLUDE_PATH="/usr/include/gdal" + +# System dependencies +# System dependencies +RUN apt-get update && \ + apt-get -y install software-properties-common && \ + add-apt-repository ppa:ubuntugis/ubuntugis-unstable && \ + curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ + apt-get update && apt-get -y dist-upgrade && \ + apt-get -y install build-essential \ + libsm6 \ + libxext6 \ + libxrender-dev \ + libfontconfig1 \ + bzip2 \ + diffutils \ + file \ + build-essential \ + make \ + swig \ + libnetcdf-dev \ + libacl1-dev \ + libgeos++-dev \ + libgeos-dev \ + libsqlite3-dev \ + libx11-dev \ + libproj-dev \ + proj-data \ + proj-bin \ + libspatialindex-dev \ + wget \ + vim \ + curl \ + git \ + procps \ + gcc \ + g++ \ + bzip2 \ + libssl-dev \ + libzmq3-dev \ + libpng-dev \ + libfreetype6-dev \ + locales \ + git-lfs && \ + apt-get -y install gdal-bin libgdal-dev && \ + apt-get -y autoremove && \ + rm -rf /var/cache/apt /var/lib/apt/lists/* + +# Install shiftc +WORKDIR /app +RUN git clone --single-branch --branch master https://github.com/pkolano/shift.git && \ + cd shift/c && \ + make nolustre && \ + cd ../ && \ + install -m 755 perl/shiftc /usr/local/bin/ && \ + install -m 755 c/shift-bin /usr/local/bin/ && \ + install -m 755 perl/shift-mgr /usr/local/bin/ && \ + install -m 644 etc/shiftrc /etc/ && \ + install -m 755 perl/shift-aux /usr/local/bin/ && \ + install -m 755 c/shift-bin /usr/local/bin/ && \ + export LC_ALL=en_US.UTF-8 && \ + export 
LANG=en_US.UTF-8 && \ + locale-gen en_US.UTF-8 && \ + rm -rf /app + +# Pip +RUN pip --no-cache-dir install omegaconf \ + pytorch-lightning \ + Lightning \ + transformers \ + datasets \ + webdataset \ + 'huggingface_hub[cli,torch]' \ + torchgeo \ + rasterio \ + rioxarray \ + xarray \ + xarray-spatial \ + geopandas \ + opencv-python \ + opencv-python-headless \ + opencv-contrib-python \ + opencv-contrib-python-headless \ + tifffile \ + webcolors \ + Pillow \ + seaborn \ + xgboost \ + tiler \ + segmentation-models \ + timm \ + supervision \ + pytest \ + coveralls \ + rtree \ + sphinx \ + sphinx_rtd_theme \ + yacs \ + termcolor \ + segmentation-models-pytorch \ + pytorch-caney \ + GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'` + +HEALTHCHECK NONE +ENTRYPOINT [] +CMD ["/bin/bash"] diff --git a/pytorch-caney/requirements/Dockerfile.dev b/pytorch-caney/requirements/Dockerfile.dev new file mode 100644 index 0000000000000000000000000000000000000000..da665931f53d732d6490c8e31849c32ab5f31ede --- /dev/null +++ b/pytorch-caney/requirements/Dockerfile.dev @@ -0,0 +1,116 @@ +# Arguments to pass to the image +ARG VERSION_DATE=23.01 +ARG FROM_IMAGE=nvcr.io/nvidia/pytorch + +# Import RAPIDS container as the BASE Image (cuda base image) +FROM ${FROM_IMAGE}:${VERSION_DATE}-py3 + +# Ubuntu needs noninteractive to be forced +ENV DEBIAN_FRONTEND noninteractive +ENV PROJ_LIB="/usr/share/proj" +ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" +ENV C_INCLUDE_PATH="/usr/include/gdal" + +# System dependencies +# System dependencies +RUN apt-get update && \ + apt-get -y install software-properties-common && \ + add-apt-repository ppa:ubuntugis/ubuntugis-unstable && \ + curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ + apt-get update && apt-get -y dist-upgrade && \ + apt-get -y install build-essential \ + libsm6 \ + libxext6 \ + libxrender-dev \ + libfontconfig1 \ + bzip2 \ + diffutils \ + file \ + build-essential \ + make \ + swig \ + libnetcdf-dev \ + libacl1-dev \ + libgeos++-dev \ + libgeos-dev \ + libsqlite3-dev \ + libx11-dev \ + libproj-dev \ + proj-data \ + proj-bin \ + libspatialindex-dev \ + wget \ + vim \ + curl \ + git \ + procps \ + gcc \ + g++ \ + bzip2 \ + libssl-dev \ + libzmq3-dev \ + libpng-dev \ + libfreetype6-dev \ + locales \ + git-lfs && \ + apt-get -y install gdal-bin libgdal-dev && \ + apt-get -y autoremove && \ + rm -rf /var/cache/apt /var/lib/apt/lists/* + +# Install shiftc +WORKDIR /app +RUN git clone --single-branch --branch master https://github.com/pkolano/shift.git && \ + cd shift/c && \ + make nolustre && \ + cd ../ && \ + install -m 755 perl/shiftc /usr/local/bin/ && \ + install -m 755 c/shift-bin /usr/local/bin/ && \ + install -m 755 perl/shift-mgr /usr/local/bin/ && \ + install -m 644 etc/shiftrc /etc/ && \ + install -m 755 perl/shift-aux /usr/local/bin/ && \ + install -m 755 c/shift-bin /usr/local/bin/ && \ + export LC_ALL=en_US.UTF-8 && \ + export LANG=en_US.UTF-8 && \ + locale-gen en_US.UTF-8 && \ + rm -rf /app + +# Pip +RUN pip --no-cache-dir install omegaconf \ + pytorch-lightning \ + Lightning \ + transformers \ + datasets \ + webdataset \ + 'huggingface_hub[cli,torch]' \ + torchgeo \ + rasterio \ + rioxarray \ + xarray \ + xarray-spatial \ + geopandas \ + opencv-python \ + opencv-python-headless \ + opencv-contrib-python \ + opencv-contrib-python-headless \ + tifffile \ + webcolors \ + Pillow \ + seaborn \ + xgboost \ + tiler \ + segmentation-models \ + timm \ + supervision \ + pytest \ + coveralls \ + rtree \ + sphinx \ + 
sphinx_rtd_theme \ + yacs \ + termcolor \ + segmentation-models-pytorch \ + GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'` + +HEALTHCHECK NONE +ENTRYPOINT [] +CMD ["/bin/bash"] diff --git a/pytorch-caney/requirements/README.md b/pytorch-caney/requirements/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e006d062a8828d17be29b379a605c3ba9c6b5332 --- /dev/null +++ b/pytorch-caney/requirements/README.md @@ -0,0 +1,35 @@ +# Requirements + +pytorch-caney can be installed and used via Anaconda environments and containers. +A Docker container is provided, and this same container can be converted +to a Singularity container without losing any functionality. + +CPU support is limited and the author does not provide any guarantee of usability. + +## Architecture + +The container is built on top of NVIDIA NGC PyTorch containers. +This application is powered by the PyTorch and PyTorch Lightning AI/ML backends. + +## Example to Download the Container via Singularity + +```bash +module load singularity +singularity build --sandbox pytorch-caney docker://nasanccs/pytorch-caney:latest +``` + +## Example to Install Anaconda Environment + +```bash +git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git +cd pytorch-caney; conda env create -f requirements/environment_gpu.yml; +conda activate pytorch-caney +``` + +## Container Usage + +As an example, you can shell into the container: + +```bash +singularity shell --nv -B /path/to/container/pytorch-caney +``` diff --git a/pytorch-caney/requirements/environment_gpu.yml b/pytorch-caney/requirements/environment_gpu.yml new file mode 100644 index 0000000000000000000000000000000000000000..696e8662cac5052c073e98b95d14651a68a80156 --- /dev/null +++ b/pytorch-caney/requirements/environment_gpu.yml @@ -0,0 +1,43 @@ +# definition file to build pytorch-caney conda environment +name: pytorch-caney + +channels: + - conda-forge + - rapidsai + - nvidia + +dependencies: + - python=3.9 + - cudatoolkit=11.8 + - pip + - gdal + - pip: + - torch>=2.0.0 + - torchvision>=0.15 + - pytorch-lightning + - omegaconf + - rasterio + - rioxarray + - xarray + - geopandas + - opencv-python + - opencv-python-headless + - opencv-contrib-python + - opencv-contrib-python-headless + - tifffile + - webcolors + - Pillow + - seaborn + - xgboost + - tiler + - segmentation-models + - pytest + - coveralls + - rtree + - sphinx + - sphinx_rtd_theme + - yacs + - termcolor + - numba + - joblib + - segmentation-models-pytorch diff --git a/pytorch-caney/requirements/requirements-test.txt b/pytorch-caney/requirements/requirements-test.txt new file mode 100644 index 0000000000000000000000000000000000000000..462d3e4c61d25e93a5a6501b57c9462823992a00 --- /dev/null +++ b/pytorch-caney/requirements/requirements-test.txt @@ -0,0 +1,28 @@ +torch>=2.0.0 +torchvision>=0.15 +pytorch-lightning +omegaconf +rasterio +rioxarray +xarray +geopandas +opencv-python +opencv-python-headless +opencv-contrib-python +opencv-contrib-python-headless +tifffile +webcolors +Pillow +seaborn +xgboost +tiler +segmentation-models +pytest +coveralls +rtree +sphinx +sphinx_rtd_theme +yacs +termcolor +numba +segmentation-models-pytorch diff --git a/pytorch-caney/requirements/requirements.txt b/pytorch-caney/requirements/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..028ca75d1072884af7773020a19c6bd091e1aef3 --- /dev/null +++ b/pytorch-caney/requirements/requirements.txt @@ -0,0 +1,30 @@ +torch>=2.0.0 +torchvision>=0.15 +pytorch-lightning +omegaconf +rasterio +rioxarray
+xarray +geopandas +opencv-python +opencv-python-headless +opencv-contrib-python +opencv-contrib-python-headless +tifffile +webcolors +Pillow +seaborn +xgboost +tiler +segmentation-models +pytest +coveralls +rtree +sphinx +sphinx_rtd_theme +yacs +termcolor +numba +segmentation-models-pytorch +joblib +GDAL>=3.3.0 diff --git a/pytorch-caney/setup.cfg b/pytorch-caney/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..18d44ca75e200913e7f8c4461b5d06e068c86f9a --- /dev/null +++ b/pytorch-caney/setup.cfg @@ -0,0 +1,49 @@ +[metadata] +name = pytorch-caney +version = attr: pytorch_caney.__version__ +description = Methods for pytorch deep learning applications +long_description = file: README.md +long_description_content_type = text/markdown +keywords = pytorch-caney, deep-learning, machine-learning +url = https://github.com/nasa-nccs-hpda/pytorch-caney +author = jordancaraballo +author_email = jordan.a.caraballo-vega@nasa.gov +license = MIT +license_file = LICENSE.md +classifiers = + Development Status :: 4 - Beta + Intended Audience :: Developers + Intended Audience :: Science/Research + Topic :: Software Development :: Libraries :: Python Modules + License :: OSI Approved :: MIT License + Programming Language :: Python :: 3 :: Only +project_urls = + Documentation = https://github.com/nasa-nccs-hpda/pytorch-caney + Source = https://github.com/nasa-nccs-hpda/pytorch-caney + Issues = https://github.com/nasa-nccs-hpda/pytorch-caney/issues + +[options] +packages = find: +zip_safe = True +include_package_data = True +platforms = any +python_requires = >= 3.7 +install_requires = + omegaconf + numpy + pandas + tqdm + xarray + rioxarray + numba + +[options.extras_require] +test = + pytest + coverage[toml] + black +docs = + pdoc==8.0.1 +all = + %(docs)s + %(test)s diff --git a/pytorch-caney/test.sh b/pytorch-caney/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..7c6a496740b434b527e1e7c1ed4d0a3a8c5c57e4 --- /dev/null +++ b/pytorch-caney/test.sh @@ -0,0 +1,22 @@ +set -xeuo pipefail + +# Install deps in a virtual env. +readonly VENV_DIR=/tmp/pytorch-caney-env +# rm -rf "${VENV_DIR}" +python3 -m venv "${VENV_DIR}" +source "${VENV_DIR}/bin/activate" +python --version + +pip install --upgrade pip +pip install --upgrade pip setuptools wheel +pip install flake8 pytype pylint pylint-exit +pip install -r requirements/requirements-test.txt + +flake8 `find pytorch_caney -name '*.py' | xargs` --count --show-source --statistics + +# Run tests using unittest. +python -m unittest discover pytorch_caney/tests + +set +u +deactivate +echo "All tests passed. Congrats!" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a220e171232b25b856a6d8ac5269edfce380a829 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch +torchvision +timm +yacs +joblib +numpy \ No newline at end of file diff --git a/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm.config.sav b/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm.config.sav new file mode 100644 index 0000000000000000000000000000000000000000..b98c28539c10e67dc41677c82eec5dfed28c929f --- /dev/null +++ b/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm.config.sav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5472aea9e29c1c2b23adc09389dc2c6ace4960a9926d49dd25a3c876b29e554 +size 2440