# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from transformers.configuration_utils import PretrainedConfig
from transformers.file_utils import (
    FLAX_WEIGHTS_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    cached_path,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
)
from transformers.modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from transformers.utils import logging

from .generation_clip_vision_marian_utils import FlaxGenerationMixin


logger = logging.get_logger(__name__)


def quick_gelu(x):
    return x * jax.nn.sigmoid(1.702 * x)


ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
    "quick_gelu": quick_gelu,
}
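# `quick_gelu` is the sigmoid-based GELU approximation used by the original CLIP
# implementation. Illustrative lookup (an assumed pattern, mirroring other Flax
# models in transformers, not code from this file): a module resolves its
# activation from its config at setup time,
#
#     self.activation_fn = ACT2FN[self.config.hidden_act]
#
# so supporting a new activation only requires registering it in ACT2FN.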


class FlaxCLIPVisionMarianPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    config_class = None
    base_model_prefix = ""

    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # These are private so they can be exposed as typed properties on derived classes.
        self._config = config
        self._module = module

        # These are public because their types are generic across derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)

        # save required_params as a set
        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
        raise NotImplementedError(f"init_weights method has to be implemented for {self}")

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @property
    def config(self) -> PretrainedConfig:
        return self._config

    @property
    def module(self) -> nn.Module:
        return self._module

    @property
    def params(self) -> Union[Dict, FrozenDict]:
        return self._params

    @property
    def required_params(self) -> Set:
        return self._required_params

    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` includes the following "
                f"parameters {self.required_params - param_keys}"
            )
        self._params = params
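
    # Illustrative use of the setter (an assumed snippet, not part of the original
    # file): the parameter tree can be replaced wholesale, e.g. after casting every
    # leaf to half precision, and the setter validates that no keys are missing:
    #
    #     model.params = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), model.params)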

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs
    ):
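        r"""
        Instantiate a pretrained Flax model from a Hugging Face Hub checkpoint or a
        local directory.

        Illustrative usage on a derived class (both the subclass name and the
        checkpoint identifier below are placeholders, not objects defined in this
        file)::

            >>> model = FlaxCLIPVisionMarianMT.from_pretrained("user/clip-vision-marian")
            >>> # load from a PyTorch checkpoint instead of a Flax one
            >>> model = FlaxCLIPVisionMarianMT.from_pretrained("user/clip-vision-marian", from_pt=True)
        """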
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True
        # Load the config if a configuration was not provided
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype
        # Load the model weights
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        f"Error: no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}, or `from_pt` is set to False while only a "
                        "PyTorch checkpoint is present"
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file "
                    f"named {FLAX_WEIGHTS_NAME} (or {WEIGHTS_NAME} when `from_pt=True`).\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None
        # instantiate a model with randomly initialized parameters
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except UnpicklingError:
                    raise EnvironmentError(f"Unable to convert {archive_file} to a Flax deserializable object.")

            # make sure all arrays are stored as jnp.arrays
            # NOTE: this works around a bug that is fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            state = jax.tree_util.tree_map(jnp.array, state)
        # if the model is a base model, only use the base_model_prefix subtree of the state
        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if the model is a head model and we are loading weights from a base model,
        # we initialize the new params dict with the base_model_prefix
        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)
        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params
        # mismatched_keys contains tuples (key, shape1, shape2) for weights in the checkpoint
        # whose shape does not match the corresponding weight in the model.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model."
                    )
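
        # Illustrative scenario (assumed, not from the original file): loading a
        # checkpoint whose embedding table was trained with a different vocabulary
        # size, e.g.
        #
        #     model = FlaxCLIPVisionMarianMT.from_pretrained(
        #         "user/clip-vision-marian", ignore_mismatched_sizes=True
        #     )
        #
        # keeps the model's randomly initialized weight for each mismatched key
        # instead of raising on the shape mismatch.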

        # add missing keys as randomly initialized parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys so they are not saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # set the correct parameters
        model.params = unflatten_dict(state)

        return model

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            params (:obj:`Union[Dict, FrozenDict]`, `optional`):
                Parameters to save in place of :obj:`self.params`.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                .. warning::

                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you
                    are pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary
                    directory instead.
            kwargs:
                Additional keyword arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # get the absolute directory
        save_directory = os.path.abspath(save_directory)

        # save the config as well (stripping the "Flax" prefix from the class name)
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # save the model weights
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            params = params if params is not None else self.params
            model_bytes = to_bytes(params)
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")