@@ -4,8 +4,8 @@ import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
-import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.moe.client.balancer import ExpertBalancer
 from hivemind.moe.client.expert import DUMMY
 from hivemind.proto import runtime_pb2
@@ -23,7 +23,7 @@ class BalancedRemoteExpert(nn.Module):
     def __init__(
         self,
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         uid_prefix: str,
         grid_size: Tuple[int, ...],
         forward_timeout: Optional[float] = None,
@@ -118,13 +118,16 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
         while True:
-            try:
-                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
-                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
-                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
-                break
-            except BaseException as e:
-                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
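+            # NOTE: with the except clause commented out, any RPC failure now propagates to the caller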
+            # try:
+            with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+            break
+            # except KeyboardInterrupt:
+            #     break
+            # except BaseException as e:
+            #     logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
@@ -143,7 +146,8 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             try:
                 with ctx.expert_balancer.use_another_expert(ctx.backward_task_size) as chosen_expert:
                     backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
-                    grad_inputs = chosen_expert.stub.forward(backward_request, timeout=ctx.backward_timeout)
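+                    # call the expert's backward RPC to obtain gradients w.r.t. the inputs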
+                    grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
                 break
             except BaseException as e:
                 logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")