from dataclasses import dataclass
from typing import Optional

import torch
from cpufeature import CPUFeature
from petals.constants import PUBLIC_INITIAL_PEERS


@dataclass
class ModelInfo:
    repo: str
    adapter: Optional[str] = None


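# Models served by the app; `adapter` optionally names a LoRA/PEFT adapter repo
# loaded on top of the base model.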
MODELS = [
    ModelInfo(repo="enoch/llama-65b-hf"),
    ModelInfo(repo="enoch/llama-65b-hf", adapter="timdettmers/guanaco-65b"),
]
DEFAULT_MODEL_NAME = "enoch/llama-65b-hf"

INITIAL_PEERS = PUBLIC_INITIAL_PEERS  # bootstrap peers of the public Petals swarm

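# Device for the locally-run part of the model (presumably passed to the Petals
# client elsewhere in the app); set to "cuda" to use a GPU.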
DEVICE = "cpu"

# Pick a dtype: let CUDA choose automatically, use bfloat16 on AVX-512-capable CPUs,
# and fall back to float32 everywhere else.
if DEVICE == "cuda":
    TORCH_DTYPE = "auto"
elif CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]:
    TORCH_DTYPE = torch.bfloat16
else:
    TORCH_DTYPE = torch.float32

STEP_TIMEOUT = 5 * 60  # 5 minutes
MAX_SESSIONS = 50  # maximum number of concurrent sessions