Quellcode durchsuchen

fix model creation

justheuristic vor 3 Jahren
Ursprung
Commit
7d68f6b9a4
3 geänderte Dateien mit 4 neuen und 2 gelöschten Zeilen
  1. 1 0
      src/client/remote_model.py
  2. 1 1
      src/client/remote_sequential.py
  3. 2 1
      src/dht_utils.py

+ 1 - 0
src/client/remote_model.py

@@ -6,6 +6,7 @@ from src.bloom import DistributedBloomConfig, BloomForCausalLM
 from src.client.remote_sequential import RemoteSequential
 from src.data_structures import UID_DELIMITER
 
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

+ 1 - 1
src/client/remote_sequential.py

@@ -37,7 +37,7 @@ class RemoteSequential(nn.Sequential):
         for uid, info in zip(self.block_uids, self.block_infos):
             assert isinstance(info, (type(None), RemoteModuleInfo)), f"Unexpected dht entry for {uid}: {info}"
             assert info is not None, f"Found no active peers for block {uid}"
-            assert isinstance(info.peer_ids, (list, tuple)), f"expected peer_ids to be list/tuple, got {info.peer_ids}"
+            assert isinstance(info.peer_ids, set), f"expected peer_ids to be a set, got {info.peer_ids}"
             assert info.uid == uid, f"The DHT entry for {uid} actually points to {info.uid}"
             assert len(info.peer_ids) > 0,  f"Found no active peers for block {uid}"
 

+ 2 - 1
src/dht_utils.py

@@ -106,7 +106,8 @@ async def _get_remote_module_infos(
     for i, uid in enumerate(uids):
         metadata = found[uid]
         if metadata is None or not isinstance(metadata.value, dict):
-            logger.error(f"Incorrect metadata for {uid}: {metadata}")
+            if metadata is not None:
+                logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
         valid_entries = set()
         for maybe_peer_id, _unused_value in metadata.value.items():