|
@@ -9,9 +9,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
|
-from petals.client.spending_policy import NoSpendingPolicy
|
|
|
+import petals.dht_utils
|
|
|
from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
|
|
|
-from petals.dht_utils import get_remote_module_infos
|
|
|
from petals.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
@@ -88,7 +87,9 @@ class RemoteSequenceManager:
|
|
|
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
|
|
|
|
|
|
def update_block_infos_(self):
|
|
|
- new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
|
|
|
+ new_block_infos = petals.dht_utils.get_remote_module_infos(
|
|
|
+ self.dht, self.block_uids, expiration_time=float("inf")
|
|
|
+ )
|
|
|
assert len(new_block_infos) == len(self.block_uids)
|
|
|
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
|
|
|
if info is None:
|