|
@@ -5,7 +5,7 @@ import psutil
|
|
import torch.backends.quantized
|
|
import torch.backends.quantized
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
import transformers
|
|
import transformers
|
|
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
|
+from hivemind.utils.logging import get_logger
|
|
from huggingface_hub import Repository
|
|
from huggingface_hub import Repository
|
|
from tqdm.auto import tqdm
|
|
from tqdm.auto import tqdm
|
|
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
from transformers.models.bloom.modeling_bloom import BloomModel
|
|
@@ -13,7 +13,6 @@ from transformers.models.bloom.modeling_bloom import BloomModel
|
|
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
|
from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
|
|
from petals.client import DistributedBloomConfig
|
|
from petals.client import DistributedBloomConfig
|
|
|
|
|
|
-use_hivemind_log_handler("in_root_logger")
|
|
|
|
logger = get_logger(__file__)
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|
|
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
|