import warnings

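# Silence the UserWarnings raised from the bitsandbytes module on import.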
warnings.filterwarnings('ignore', category=UserWarning, module='bitsandbytes')

import logging

from .logging_utils import SpecificWarningFilter

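# Filter out the "A new version of the following files was downloaded from ..."
# notice that transformers.dynamic_module_utils logs when it refreshes remote code files.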
hf_dynamic_modules_logger = logging.getLogger('transformers.dynamic_module_utils')
new_files_warning_filter = SpecificWarningFilter('A new version of the following files was downloaded from')

hf_dynamic_modules_logger.addFilter(new_files_warning_filter)

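# Re-export the package's public submodules and top-level classes.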
from . import algorithms, callbacks, loggers, optim, registry, utils
from .data import ConcatTokensDataset, NoConcatDataset, Seq2SeqFinetuningCollator, build_finetuning_dataloader
from .hf import ComposerHFCausalLM, ComposerHFT5
from .attention import MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, scaled_multihead_dot_product_attention
from .blocks import MPTBlock
from .ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from .mpt import ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel
from .tokenizers import TiktokenTokenizerWrapper

__version__ = '0.7.0'