Parcourir la source

improve type traits

Pavel Samygin il y a 3 ans
Parent
commit
c7e3566803
1 fichiers modifiés avec 2 ajouts et 2 suppressions
  1. 2 2
      src/dht_utils.py

+ 2 - 2
src/dht_utils.py

@@ -78,7 +78,7 @@ def get_remote_module(
     config: src.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> List[src.RemoteTransformerBlock]:
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
@@ -95,7 +95,7 @@ async def _get_distinct_blocks(
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
     config: src.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-):
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()