Browse Source

check for none

justheuristic 3 years ago
parent
commit
7903bd8f9f
2 changed files with 8 additions and 1 deletions
  1. 5 0
      src/client/remote_block.py
  2. 3 1
      src/client/remote_sequential.py

+ 5 - 0
src/client/remote_block.py

@@ -29,6 +29,11 @@ class RemoteTransformerBlock(RemoteExpert):
     def stub(self) -> StubBase:
         return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
 
+    def forward(self, inputs: torch.Tensor, **kwargs):
+        for k, v in kwargs.items():
+            assert v is None, f"Extra keyword arguments are not yet supported (got {k} = {v})"
+        return super().forward(inputs)
+
     def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker

+ 3 - 1
src/client/remote_sequential.py

@@ -16,7 +16,9 @@ logger = get_logger(__file__)
 
 
 class RemoteSequential(nn.Sequential):
-    """A sequence of transformer blocks hosted by the swarm"""
+    """
+    A sequence of transformer blocks hosted by the swarm.
+    """
 
     def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: Optional[str] = None, max_retries: int = 3):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")