@@ -159,7 +159,9 @@ class _RemoteCallMany(torch.autograd.Function):
                 info: Dict[str, Any], *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         assert not torch.is_grad_enabled()
         num_samples, max_experts = len(experts_per_sample), max(map(len, experts_per_sample))
-        flat_inputs_per_sample: List[Tuple[torch.Tensor, ...]] = list(zip(*(x.split(1, dim=0) for x in flat_inputs)))
+
+        flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]
+        flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))
         assert len(experts_per_sample) == len(flat_inputs_per_sample) == num_samples

         # dispatch tasks to all remote experts collect responses
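As a side note on what the two added lines compute: the inputs are moved to CPU once, then each batched field is split into per-sample rows and transposed so that element i holds every field of sample i. A minimal standalone sketch of that pattern (the shapes and names below are illustrative only, not taken from the patch):

import torch

# a hypothetical forward call with two flattened input fields, possibly on GPU
batch_size, hidden_dim = 4, 3
flat_inputs = (torch.randn(batch_size, hidden_dim), torch.randn(batch_size, 1))

# move everything to CPU once, before serialization / save_for_backward
flat_inputs_cpu = [tensor.cpu() for tensor in flat_inputs]

# split each field into per-sample rows, then transpose so that element i holds
# all fields of sample i as a tuple of tensors shaped [1, ...]
flat_inputs_per_sample = list(zip(*(x.split(1, dim=0) for x in flat_inputs_cpu)))

assert len(flat_inputs_per_sample) == batch_size
assert flat_inputs_per_sample[0][0].shape == (1, hidden_dim)
assert flat_inputs_per_sample[0][1].shape == (1, 1)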
@@ -167,7 +169,7 @@ class _RemoteCallMany(torch.autograd.Function):
         for i in range(num_samples):
             for j, expert in enumerate(experts_per_sample[i]):
                 input_tensors = [serialize_torch_tensor(tensor, proto.compression) for tensor, proto in zip(
-                    flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
+                    flat_inputs_per_sample[i], nested_flatten(info['forward_schema']))]
                 stub: runtime_grpc.ConnectionHandlerStub = _get_expert_stub(expert.endpoint)
                 new_task = stub.forward.future(runtime_pb2.ExpertRequest(uid=expert.uid, tensors=input_tensors))
                 pending_tasks[new_task] = (i, j)
@@ -182,8 +184,8 @@ class _RemoteCallMany(torch.autograd.Function):
         mask = torch.zeros([num_samples, max_experts], dtype=torch.bool, device=flat_inputs[0].device)
         mask[alive_ii, alive_jj] = True

-        alive_flat_outputs_stacked = list(map(torch.cat, zip(*alive_flat_outputs)))
-        # list of torch tensors, where i-th tensor is of shape [num_responded, *expert_outputs[i].shape]
+        alive_flat_outputs_stacked = (torch.cat(outputs) for outputs in zip(*alive_flat_outputs))
+        # torch tensors, i-th tensor is of shape [num_responded, *expert_outputs[i].shape]

         outputs = []
         for response_stacked in alive_flat_outputs_stacked:
@@ -191,10 +193,10 @@ class _RemoteCallMany(torch.autograd.Function):
                 [num_samples, max_experts, *response_stacked.shape[1:]], device=response_stacked.device,
                 dtype=response_stacked.dtype, requires_grad=response_stacked.requires_grad)
             output[alive_ii, alive_jj] = response_stacked
-            outputs.append(output)
+            outputs.append(output.to(flat_inputs[0].device))

         # save individual outputs for backward pass
-        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs)
+        ctx.save_for_backward(alive_ii, alive_jj, *flat_inputs_cpu)
         ctx._saved_non_tensors = info, backward_k_min, backward_timeout, timeout_after_k_min, experts_per_sample
         return (mask,) + tuple(outputs)

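To see why the new `.to(flat_inputs[0].device)` call is needed: deserialized expert responses arrive as CPU tensors, so the dense output buffer is assembled on CPU and only the finished result is moved back to the device of the original inputs. A toy reconstruction of that step, with made-up sizes and names:

import torch

num_samples, max_experts, out_dim = 3, 2, 4
device = "cuda" if torch.cuda.is_available() else "cpu"

# which (sample, expert_slot) calls returned in time, and their stacked
# deserialized outputs (deserialization produces CPU tensors)
alive_ii = torch.tensor([0, 0, 2])
alive_jj = torch.tensor([0, 1, 0])
response_stacked = torch.randn(3, out_dim)  # [num_responded, out_dim], on CPU

# scatter responses into a dense [num_samples, max_experts, out_dim] buffer on CPU,
# then move the finished buffer to the device of the original inputs
output = torch.zeros(num_samples, max_experts, out_dim, dtype=response_stacked.dtype)
output[alive_ii, alive_jj] = response_stacked
output = output.to(device)
assert output.shape == (num_samples, max_experts, out_dim)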
@@ -203,12 +205,15 @@ class _RemoteCallMany(torch.autograd.Function):
     def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]:
         assert not torch.is_grad_enabled()
         info, backward_k_min, backward_timeout, timeout_after_k_min, expert_per_sample = ctx._saved_non_tensors
-        alive_ii, alive_jj, *flat_inputs = ctx.saved_tensors
+        alive_ii, alive_jj, *flat_inputs_cpu = ctx.saved_tensors
+
         dummy_grad_mask, *flat_grad_outputs = raw_grads
+        flat_grad_outputs_cpu = [tensor.cpu() for tensor in flat_grad_outputs]
+
         num_samples, max_experts = dummy_grad_mask.shape

-        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs))
-        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs))
+        inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
+        grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))
         backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))

         # dispatch tasks to all remote experts, collect responses
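For context (not part of the patch), the two rewritten `*_per_expert` generators pair, for every alive (sample, expert) call, that sample's saved CPU inputs with the matching slice of the incoming gradients, all on CPU so they can be serialized for the backward RPCs. A toy version with made-up shapes:

import torch

num_samples, max_experts, dim = 3, 2, 4
flat_inputs_cpu = [torch.randn(num_samples, dim)]                      # saved in forward, on CPU
flat_grad_outputs_cpu = [torch.randn(num_samples, max_experts, dim)]   # grads w.r.t. stacked outputs, moved to CPU

alive_ii = torch.tensor([0, 0, 2])   # sample index of each alive expert call
alive_jj = torch.tensor([0, 1, 0])   # expert slot of each alive expert call

inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
grad_outputs_per_expert = zip(*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu))

for inputs, grad_outputs in zip(inputs_per_expert, grad_outputs_per_expert):
    # one backward RPC per alive expert call: its input row and its gradient row
    assert inputs[0].shape == (1, dim) and grad_outputs[0].shape == (1, dim)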
@@ -230,17 +235,19 @@ class _RemoteCallMany(torch.autograd.Function):

         # assemble responses
         backward_survivor_ii, backward_survivor_jj = map(torch.as_tensor, zip(*backward_survivor_indices) or ([], []))
-        survivor_grad_inputs_stacked = list(map(torch.cat, zip(*survivor_grad_inputs)))
-        # list of torch tensors, where i-th tensor is of shape [num_backward_survivors, *flat_inputs[i].shape]
+
+        survivor_grad_inputs_stacked = (torch.cat(grad_inputs) for grad_inputs in zip(*survivor_grad_inputs))
+        # torch tensors, i-th tensor is of shape [num_backward_survivors, *flat_inputs_cpu[i].shape]

         grad_inputs = []
         for i, survivor_grad_stacked in enumerate(survivor_grad_inputs_stacked):
             grad_input_per_expert = torch.zeros(  # gradient tensor with individual contributions from each expert
-                (num_samples, max_experts, *flat_inputs[i].shape[1:]),
+                (num_samples, max_experts, *flat_inputs_cpu[i].shape[1:]),
                 device=survivor_grad_stacked.device, dtype=survivor_grad_stacked.dtype)
             grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked

-            grad_inputs.append(grad_input_per_expert.sum(dim=1))  # add up gradients from each expert
+            # sum gradients from each expert
+            grad_inputs.append(grad_input_per_expert.to(flat_grad_outputs[0].device).sum(dim=1))

         return (DUMMY, None, None, None, None, None, None, None, *grad_inputs)

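For completeness, a standalone sketch of the final gradient-assembly step: survivor gradients come back as CPU tensors, are scattered into a zero buffer with one slot per (sample, expert) pair, summed over the expert dimension, and only then moved to the device the incoming gradients live on. All shapes and names here are illustrative:

import torch

num_samples, max_experts, dim = 3, 2, 4
device = "cuda" if torch.cuda.is_available() else "cpu"

backward_survivor_ii = torch.tensor([0, 2])   # samples whose experts returned gradients in time
backward_survivor_jj = torch.tensor([1, 0])   # corresponding expert slots
survivor_grad_stacked = torch.randn(2, dim)   # [num_backward_survivors, dim], deserialized on CPU

# one slot per (sample, expert) pair; slots without a surviving gradient stay zero
grad_input_per_expert = torch.zeros(num_samples, max_experts, dim, dtype=survivor_grad_stacked.dtype)
grad_input_per_expert[backward_survivor_ii, backward_survivor_jj] = survivor_grad_stacked

# sum the per-expert contributions for each sample, then hand the result back on
# the device of the incoming gradients, as the autograd engine expects
grad_input = grad_input_per_expert.to(device).sum(dim=1)
assert grad_input.shape == (num_samples, dim)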