Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import re

import humanfriendly
import numpy as np
import torch
def get_human_readable_count(number: int) -> str:
    """Return a short human-readable rendering of a (parameter) count.

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    Abbreviates a non-negative number with K, M, B, T for thousands, millions,
    billions and trillions, respectively. The scaled value is always printed
    with two decimals. Values of a quadrillion or more stay expressed in
    trillions.

    Examples:
        >>> get_human_readable_count(123)
        '123.00  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.23 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2.00 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3.00 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4.00 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5000.00 T'

    Args:
        number: a non-negative number, typically an integer count.

    Return:
        ``f"{scaled:.2f} {label}"`` where label is one of " ", "K", "M", "B",
        "T". Values below one thousand use the single-space label, so the
        result ends with two spaces.

    Raises:
        AssertionError: if ``number`` is negative (an ``assert``, so the check
        is skipped under ``python -O``).
    """
    assert number >= 0
    labels = [" ", "K", "M", "B", "T"]
    # Count the digits of the integer part exactly. ``floor(log10(x)) + 1``
    # yields 0 for 0 < x < 1, which would produce group index -1 (label "T")
    # and scale the value up by 1000.
    num_digits = len(str(int(number)))
    # One group per started block of three digits (ceil division), capped at
    # the largest available label.
    num_groups = min((num_digits + 2) // 3, len(labels))
    shift = -3 * (num_groups - 1)
    number = number * (10**shift)
    index = num_groups - 1
    return f"{number:.2f} {labels[index]}"
def to_bytes(dtype) -> int:
    """Return the storage size in bytes of a single element of ``dtype``.

    Accepts a ``torch.dtype``, a ``numpy.dtype``, or a name ending in its bit
    width (e.g. ``"float32"``). Parsing only the last two characters of the
    name fails for 8-bit and bool types ("t8", "ol"). It is also wrong for
    widths with three digits (``complex128``), so the dtype's own size is
    preferred.

    Args:
        dtype: the element type to measure.

    Returns:
        The number of bytes per element.

    Raises:
        ValueError: if the size cannot be determined.
    """
    # torch.dtype (torch >= 2.1) and numpy.dtype both expose ``itemsize``.
    itemsize = getattr(dtype, "itemsize", None)
    if isinstance(itemsize, int):
        return itemsize
    # Older torch releases: ask an empty tensor of that dtype.
    if isinstance(dtype, torch.dtype):
        return torch.empty((), dtype=dtype).element_size()
    # Last resort: the trailing bit width of the name, e.g. "float32" -> 4.
    match = re.search(r"(\d+)$", str(dtype))
    if match is None:
        raise ValueError(f"Cannot determine the element size of dtype {dtype!r}")
    return int(match.group(1)) // 8
def model_summary(model: torch.nn.Module) -> str:
    """Build a printable summary of ``model``.

    The summary consists of ``str(model)`` followed by:
    - the class name;
    - the total and trainable parameter counts, formatted with
      ``get_human_readable_count``;
    - the storage size of the trainable parameters, formatted with
      ``humanfriendly.format_size``;
    - the dtype of the first parameter.

    Args:
        model: the module to summarize.

    Returns:
        The multi-line summary string, without a trailing newline.
    """
    # ``parameters()`` yields a fresh generator each call; materialize once.
    params = list(model.parameters())
    trainable = [p for p in params if p.requires_grad]

    tot_params = sum(p.numel() for p in params)
    num_params = sum(p.numel() for p in trainable)
    # Parameter-free modules (e.g. ``nn.ReLU()``) would otherwise divide by 0.
    percent = num_params * 100.0 / tot_params if tot_params else 0.0
    percent_trainable = "{:.1f}".format(percent)
    tot_params_str = get_human_readable_count(tot_params)
    num_params_str = get_human_readable_count(num_params)

    # NOTE: only trainable parameters count towards the reported size.
    # ``Tensor.element_size()`` is exact for every dtype, unlike parsing the
    # bit width out of the dtype name (which fails for int8/bool).
    num_bytes = humanfriendly.format_size(
        sum(p.numel() * p.element_size() for p in trainable)
    )
    # Only the first parameter's dtype is reported; None when there are none.
    dtype = params[0].dtype if params else None

    lines = [
        "Model structure:",
        str(model),
        "",
        "Model summary:",
        f" Class Name: {model.__class__.__name__}",
        f" Total Number of model parameters: {tot_params_str}",
        f" Number of trainable parameters: {num_params_str} ({percent_trainable}%)",
        f" Size: {num_bytes}",
        f" Type: {dtype}",
    ]
    return "\n".join(lines)