|
@@ -78,7 +78,7 @@ def get_remote_module(
|
|
|
config: src.DistributedBloomConfig,
|
|
|
dht_prefix: Optional[str] = None,
|
|
|
return_future: bool = False,
|
|
|
-) -> List[src.RemoteTransformerBlock]:
|
|
|
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
|
|
|
"""
|
|
|
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
|
|
:param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
|
|
@@ -95,7 +95,7 @@ async def _get_distinct_blocks(
|
|
|
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
config: src.DistributedBloomConfig,
|
|
|
dht_prefix: Optional[str] = None,
|
|
|
-):
|
|
|
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
|
|
|
single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
|
uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
|
p2p = await dht.replicate_p2p()
|