@@ -37,14 +37,14 @@ class RemoteMixtureOfExperts(nn.Module):
     allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
     allow_broadcasting=False will raise an error
     """
-    def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
-                 k_best, k_min=1, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
+    def __init__(self, *, in_features, grid_size: Tuple[int], network, k_best, k_min=1,
+                 forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
                  uid_prefix='', expert_padding=None, allow_broadcasting=True):
         super().__init__()
         self.network, self.grid_size = network, grid_size
         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
         self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
-        self.timeout_after_k_min, self.backward_timeout = timeout_after_k_min, backward_timeout
+        self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
         self.allow_broadcasting = allow_broadcasting

         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
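
For context, a minimal construction sketch under the new signature (values are hypothetical; `dht` stands in for whatever network handle the surrounding codebase provides, and the forward call follows the docstring's 2-D case):

    import torch

    # Hypothetical setup: `dht` is an existing network handle, not defined in this diff.
    moe = RemoteMixtureOfExperts(
        in_features=512,
        grid_size=(32, 32),       # expert grid; logits for all grid dimensions are predicted jointly
        network=dht,
        k_best=4,                 # query/wait for the 4 highest-scoring experts
        k_min=1,                  # minimum number of expert responses before the grace period starts
        forward_timeout=5.0,      # new in this diff; None presumably means no overall forward time limit
        timeout_after_k_min=1.0,  # extra time to wait once k_min experts have responded
    )

    output = moe(torch.randn(16, 512))  # [batch, in_features] -> mixture output

Note that `num_workers` is dropped from the signature in the same change, so callers passing it by keyword would need updating.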