@@ -215,7 +215,7 @@ class _RemoteMoECall(torch.autograd.Function):
         """ Like normal backward, but we ignore any experts that failed during backward pass """
         expert_logits, alive_ix, alive_expert_probas, *stacked_alive_outputs = ctx.saved_tensors
         alive_contexts, k_min, timeout = ctx._alive_contexts, ctx._backward_k_min, ctx._backward_timeout
-
+        print(grad_outputs_flat)
         jobs = [partial(cls._run_expert_backward, ctx, prob, *grad_outputs_flat)
                 for ctx, prob in zip(alive_contexts, alive_expert_probas.split(1))]
         results = run_and_await_k(jobs, k=k_min, timeout_after_k=None, timeout_total=timeout)
@@ -246,7 +246,6 @@ class _RemoteMoECall(torch.autograd.Function):
 
     @staticmethod
     def _run_expert_backward(ctx: EmulatedAutogradContext, weight: torch.Tensor, *grad_outputs: torch.Tensor):
-        print('!!!', [g.shape for g in grad_outputs])
         backward_result = run_isolated_backward(_RemoteModuleCall, ctx, *(grad * weight for grad in grad_outputs))
         grad_dummy, no_grad_uid, no_grad_hostname, no_grad_port, *grad_inputs = backward_result
         return grad_inputs
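
For context, a minimal sketch of the k-of-n await pattern the backward pass above relies on: each alive expert's backward is wrapped in a functools.partial job, and run_and_await_k returns once at least k_min jobs have succeeded, ignoring experts that fail or miss the timeout. The sketch below is an illustrative re-implementation inferred only from how the helper is called in this diff (argument-free callables plus k, timeout_after_k, timeout_total); the actual helper in the repository may be implemented differently.

# Illustrative sketch, not the repository's real run_and_await_k: run argument-free
# callables (e.g. functools.partial jobs) concurrently and return once at least k of
# them finished successfully. Failed or late jobs are silently dropped, mirroring the
# "ignore experts that failed during backward" behaviour described in the docstring.
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def run_and_await_k_sketch(jobs, k, timeout_after_k=None, timeout_total=None):
    results = []
    deadline = None if timeout_total is None else time.monotonic() + timeout_total
    pool = ThreadPoolExecutor(max_workers=max(1, len(jobs)))
    pending = {pool.submit(job) for job in jobs}
    try:
        while pending and len(results) < k:
            remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
            done, pending = wait(pending, timeout=remaining, return_when=FIRST_COMPLETED)
            if not done:  # total timeout expired before k successes
                break
            results.extend(f.result() for f in done if f.exception() is None)
        # optional grace period for stragglers once k results have been collected
        if pending and timeout_after_k is not None:
            done, pending = wait(pending, timeout=timeout_after_k)
            results.extend(f.result() for f in done if f.exception() is None)
    finally:
        pool.shutdown(wait=False)  # do not block on experts that never respond
    return results

Read against the diff, the call in _backward then amounts to: collect per-expert gradient jobs for every alive expert, wait until at least k_min of them return, and give up on the rest after the configured backward timeout.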