|
@@ -11,7 +11,7 @@ from typing import Dict, Union
|
|
|
import torch
|
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
|
-from petals import project_name
|
|
|
+import petals
|
|
|
from petals.bloom.block import BloomBlock
|
|
|
from petals.bloom.model import BloomConfig
|
|
|
from petals.bloom.ops import build_alibi_tensor
|
|
@@ -20,8 +20,8 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
|
|
|
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
|
|
|
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput.json")
|
|
|
+DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
|
|
|
|
|
|
SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
|
|
|
|