|
@@ -26,8 +26,8 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
:param grid_size: tesseract dimensions that form expert uid (see below)
|
|
|
:param uid_prefix: common prefix for all expert uids
|
|
|
expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
|
|
|
- :param network: TesseractNetwork where the experts reside
|
|
|
- :param num_workers: number of threads for parallel network operation
|
|
|
+ :param dht: DHTNode where the experts reside
|
|
|
+ :param num_workers: number of threads for parallel dht operation
|
|
|
:param k_best: queries this many experts with highest scores
|
|
|
:param k_min: makes sure at least this many experts returned output
|
|
|
:param timeout_after_k_min: waits for this many seconds after k_min experts returned results.
|
|
@@ -37,11 +37,11 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
allow_broadcasting=True will flatten first d-1 input dimensions, apply RemoteMixtureOfExperts and un-flatten again
|
|
|
allow_broadcasting=False will raise an error
|
|
|
"""
|
|
|
- def __init__(self, *, in_features, grid_size: Tuple[int], network, k_best, k_min=1,
|
|
|
+ def __init__(self, *, in_features, grid_size: Tuple[int], dht, k_best, k_min=1,
|
|
|
forward_timeout=None, timeout_after_k_min=1.0, backward_k_min=1, backward_timeout=None,
|
|
|
uid_prefix='', expert_padding=None, allow_broadcasting=True):
|
|
|
super().__init__()
|
|
|
- self.network, self.grid_size = network, grid_size
|
|
|
+ self.dht, self.grid_size = dht, grid_size
|
|
|
self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
|
|
|
self.k_best, self.k_min, self.backward_k_min = k_best, k_min, backward_k_min
|
|
|
self.forward_timeout, self.timeout_after_k_min, self.backward_timeout = forward_timeout, timeout_after_k_min, backward_timeout
|
|
@@ -94,7 +94,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
:param grid_scores: scores predicted for each dimension in the grid,
|
|
|
:type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]]
|
|
|
:param k_best: how many of the top experts participate in the computation
|
|
|
- :param kwargs: extra keyword parameters passed to self.network.first_k_active
|
|
|
+ :param kwargs: extra keyword parameters passed to self.dht.first_k_active
|
|
|
:returns: a list of *batch_size* lists that contain chosen experts for one sample each inner list contains \
|
|
|
RemoteExpert instances for *up to* k_best experts
|
|
|
"""
|
|
@@ -104,7 +104,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
beam = np.array([[self.uid_prefix]] * batch_size, dtype=object) # [batch_size, up_to_beam_size]
|
|
|
scores = np.zeros([batch_size, 1], dtype=np.float64)
|
|
|
|
|
|
- delimeters = np.array(self.network.UID_DELIMETER)[None, None, None] # pre-compute numpy array for fast concat
|
|
|
+ delimeters = np.array(self.dht.UID_DELIMETER)[None, None, None] # pre-compute numpy array for fast concat
|
|
|
|
|
|
for dim_index, dim_scores in enumerate(grid_scores):
|
|
|
dim_scores = check_numpy(dim_scores)
|
|
@@ -121,7 +121,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
# select k best candidates according to scores but only those that are still active
|
|
|
new_order = np.argsort(- new_scores, axis=-1)
|
|
|
top_alive_lookups = [
|
|
|
- run_in_background(self.network.first_k_active, cands[order], k_best, **kwargs)
|
|
|
+ run_in_background(self.dht.first_k_active, cands[order], k_best, **kwargs)
|
|
|
for cands, order in zip(new_candidates, new_order)]
|
|
|
|
|
|
batch_cand_to_score = [
|
|
@@ -137,7 +137,7 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
scores = np.array([row + [-float('inf')] * (k_best - len(row))
|
|
|
for row in top_alive_scores], dtype='float32')
|
|
|
|
|
|
- unique_experts = self.network.get_experts(list(set(
|
|
|
+ unique_experts = self.dht.get_experts(list(set(
|
|
|
uid for row in beam for uid in row if uid != self.expert_padding)))
|
|
|
if self._outputs_schema is None:
|
|
|
self._outputs_schema = next(iter(unique_experts)).info['outputs_schema']
|
|
@@ -160,8 +160,8 @@ class RemoteMixtureOfExperts(nn.Module):
|
|
|
|
|
|
grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
|
|
|
for i, expert in enumerate(flat_experts):
|
|
|
- expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):]
|
|
|
- expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER)))
|
|
|
+ expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMETER):]
|
|
|
+ expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMETER)))
|
|
|
grid_indices[i] = expert_indices
|
|
|
|
|
|
scores_per_dim = [
|