|
@@ -25,6 +25,7 @@ class RemoteExpert(nn.Module):
|
|
|
self._info = None
|
|
|
|
|
|
def forward(self, *args, **kwargs):
|
|
|
+ """ Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd. """
|
|
|
assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}"
|
|
|
kwargs = {key: kwargs[key] for key in self.info['keyword_names']}
|
|
|
# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
|