|
@@ -9,7 +9,7 @@ import torch.nn as nn
|
|
|
from torch.autograd.function import once_differentiable
|
|
|
|
|
|
from .expert import RemoteExpert, _RemoteModuleCall
|
|
|
-from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY
|
|
|
+from ..utils import nested_map, check_numpy, run_and_await_k, nested_pack, nested_flatten, DUMMY, run_in_background
|
|
|
from ..utils import run_isolated_forward, EmulatedAutogradContext, run_isolated_backward, map_with_parallel_backward
|
|
|
|
|
|
|
|
@@ -33,17 +33,20 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
:param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
|
Any expert that didn't manage to return output after that delay is considered unavailable
|
|
|
:param expert_padding: internal value used to denote "absent expert". Should not coincide with any expert uid.
|
|
|
+ :param allow_broadcasting: if RemoteMixtureOfExperts is fed an input with more than 2 dimensions,
|
|
|
+ allow_broadcasting=True will flatten the first d-1 input dimensions, apply RemoteMixtureOfExperts, and un-flatten the result
|
|
|
+ allow_broadcasting=False will raise an error
|
|
|
"""
|
|
|
def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
|
|
|
k_best, k_min=1, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
|
|
|
- uid_prefix='', expert_padding=None):
|
|
|
+ uid_prefix='', expert_padding=None, allow_broadcasting=True):
|
|
|
super().__init__()
|
|
|
self.network, self.grid_size = network, grid_size
|
|
|
self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
|
|
|
self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
|
|
|
self.timeout_after_k_min, self.backward_timeout = timeout_after_k_min, backward_timeout
|
|
|
+ self.allow_broadcasting = allow_broadcasting
|
|
|
|
|
|
- self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
|
|
|
self.proj = nn.Linear(in_features, sum(grid_size)) # jointly predict logits for all grid dimensions
|
|
|
|
|
|
# grab some expert to set ensemble output shape
|
|
@@ -121,13 +124,13 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
# select k best candidates according to scores but only those that are still active
|
|
|
new_order = np.argsort(- new_scores, axis=-1)
|
|
|
top_alive_lookups = [
|
|
|
- self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs)
|
|
|
+ run_in_background(self.network.first_k_active, cands[order], k_best, **kwargs)
|
|
|
for cands, order in zip(new_candidates, new_order)]
|
|
|
|
|
|
batch_cand_to_score = [
|
|
|
dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)]
|
|
|
|
|
|
- top_alive_prefixes = [result.get() for result in top_alive_lookups]
|
|
|
+ top_alive_prefixes = [result.result() for result in top_alive_lookups]
|
|
|
top_alive_scores = [list(map(cand_to_score.get, top_cands))
|
|
|
for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)]
|
|
|
|