@@ -3,7 +3,7 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple

 import torch
 import torch.nn as nn
@@ -99,7 +99,7 @@ class RemoteExpertWorker:

     _task_queue: Queue = Queue()
     _event_thread: Optional[Thread] = None
-    _pid: int = 0
+    _pid: int = -1

     @classmethod
     def _run(cls):
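
Note: the -1 sentinel only matters if run_coroutine compares _pid against os.getpid() to decide whether the background event-loop thread still belongs to the current process; that comparison is outside this hunk, so the following is a sketch of the assumed fork-detection pattern with hypothetical names, not code from this file:

    import os
    from threading import Thread
    from typing import Optional

    class Worker:
        _event_thread: Optional[Thread] = None
        _pid: int = -1  # -1 never equals a real pid, so the first call in any process spawns a fresh thread

        @classmethod
        def _ensure_event_thread(cls):
            # After a fork, the child inherits class attributes but not the parent's running thread;
            # comparing the stored pid with os.getpid() detects this and restarts the worker.
            if cls._event_thread is None or cls._pid != os.getpid():
                cls._pid = os.getpid()
                cls._event_thread = Thread(target=lambda: None, daemon=True)
                cls._event_thread.start()
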
@@ -113,7 +113,8 @@ class RemoteExpertWorker:
                 except Exception as e:
                     future.set_exception(e)
                     continue
-                future.set_result(result)
+                if not future.cancelled():
+                    future.set_result(result)

         loop.run_until_complete(receive_tasks())

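
Note: calling set_result on a concurrent.futures.Future that the caller has already cancelled raises InvalidStateError (Python 3.8+); since set_result sits outside the try/except above, that exception would escape receive_tasks and stop the worker loop. The added cancelled() check drops the result instead. A standalone illustration of the failure mode, separate from this diff:

    from concurrent.futures import Future, InvalidStateError

    future = Future()
    future.cancel()  # the caller gave up while the coroutine was still running
    try:
        future.set_result("done")  # what the old code did unconditionally
    except InvalidStateError:
        print("set_result on a cancelled future raises InvalidStateError")

    if not future.cancelled():  # the new guard: skip delivery instead of crashing the loop
        future.set_result("done")
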
@@ -151,7 +152,7 @@ class RemoteExpertWorker:
     @classmethod
     def spawn_experts_future(
         cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
-    ) -> MPFuture[List[Optional[RemoteExpert]]]:
+    ) -> Future[List[Optional[RemoteExpert]]]:
         async def _unpack():
             p2p = cls.run_coroutine(dht.replicate_p2p(), True)
             return cls.spawn_experts(await infos, await p2p)
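
Note: with return_future=True, run_coroutine hands back the object that receive_tasks later resolves, and given the concurrent.futures.Future import in the header above, that is a plain Future rather than an MPFuture; this annotation fix (and the parameter fix in the next hunk) reflects that. A simplified, non-verbatim sketch of the queue-and-future bridge, assuming names similar to the ones in this file:

    import asyncio
    from concurrent.futures import Future
    from queue import Queue
    from threading import Thread

    _task_queue: Queue = Queue()

    def run_coroutine(coro, return_future: bool = False):
        future: Future = Future()        # resolved later by the event-loop thread
        _task_queue.put((coro, future))
        return future if return_future else future.result()

    def _run():
        async def receive_tasks():
            while True:
                coro, future = _task_queue.get()
                try:
                    result = await coro
                except Exception as e:
                    future.set_exception(e)
                    continue
                if not future.cancelled():
                    future.set_result(result)

        asyncio.new_event_loop().run_until_complete(receive_tasks())

    Thread(target=_run, daemon=True).start()
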
@@ -166,7 +167,7 @@ class RemoteExpertWorker:

     @classmethod
     def spawn_experts_bulk_future(
-        cls, infos: MPFuture[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
+        cls, infos: Future[Sequence[Sequence[Optional[RemoteExpertInfo]]]], dht: DHT
     ) -> MPFuture[List[List[Optional[RemoteExpert]]]]:
         async def _unpack():
             return cls.spawn_experts_bulk(await infos, dht)
@@ -174,6 +175,75 @@ class RemoteExpertWorker:
         return cls.run_coroutine(_unpack, True)


+async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+
+    grad_inputs = await stub.rpc_backward_stream(
+        amap_in_executor(
+            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            as_aiter(*split),
+        ),
+    )
+
+    return await gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+
+
+async def _backward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def expert_backward(
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
+) -> List[torch.Tensor]:
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
+    )
+
+    size = 0
+    for t in inputs_and_grads:
+        size += t.element_size() * t.nelement()
+    if size >= DEFAULT_MAX_MSG_SIZE:
+        return await _backward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _backward(uid, serialized_tensors, stub)
+
+
+async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+
+    outputs = await stub.rpc_forward_stream(
+        amap_in_executor(
+            lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
+            as_aiter(*split),
+        ),
+    )
+
+    return await gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+
+
+async def _forward(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
+    serialized_tensors = (
+        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
+    )
+    size = 0
+    for t in inputs:
+        size += t.element_size() * t.nelement()
+    if size >= DEFAULT_MAX_MSG_SIZE:
+        return await _forward_stream(uid, serialized_tensors, stub)
+    else:
+        return await _forward(uid, serialized_tensors, stub)
+
+
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

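
Note: the new module-level coroutines gather the serialize / size-check / unary-vs-streaming dispatch that was previously duplicated inside _RemoteModuleCall, so forward and backward (rewritten in the next hunk) reduce to a single run_coroutine call each. A minimal usage sketch mirroring that call, where expert_uid, forward_schema, and stub are assumed to come from an already-connected RemoteExpert:

    import torch

    inputs = [torch.randn(1, 1024)]
    # one compression setting per flattened input tensor, taken from the expert's schema
    compressions = (p.compression for p in nested_flatten(forward_schema))

    # run_coroutine bridges the synchronous caller into the worker's event loop;
    # expert_forward itself chooses rpc_forward vs rpc_forward_stream from the total payload size.
    outputs = RemoteExpertWorker.run_coroutine(
        expert_forward(expert_uid, inputs, compressions, stub)
    )
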
@@ -193,93 +263,21 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)

-        serialized_tensors = (
-            serialize_torch_tensor(inp, proto.compression)
-            for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
+        deserialized_outputs = RemoteExpertWorker.run_coroutine(
+            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
         )

-        size = 0
-        for t in inputs:
-            size += t.element_size() * t.nelement()
-            if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_outputs = cls.forward_stream(serialized_tensors, ctx, stub)
-                break
-        else:
-            deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
-
         return tuple(deserialized_outputs)

-    @classmethod
-    def forward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
-
-        outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward_stream(
-                amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
-                    as_aiter(*split),
-                ),
-            )
-        )
-
-        return RemoteExpertWorker.run_coroutine(
-            gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
-
-    @classmethod
-    def forward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-
-        outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
-        )
-
-        return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
     @classmethod
     @once_differentiable
     def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = (
-            serialize_torch_tensor(tensor, proto.compression)
-            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        )
-
-        size = 0
-        for t in inputs_and_grad_outputs:
-            size += t.element_size() * t.nelement()
-            if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_grad_inputs = cls.backward_stream(serialized_tensors, ctx)
-                break
-        else:
-            deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
-
-        return (DUMMY, None, None, None, *deserialized_grad_inputs)
-
-    @classmethod
-    @once_differentiable
-    def backward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
-
-        grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward_stream(
-                amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
-                    as_aiter(*split),
-                ),
-            )
-        )

-        return RemoteExpertWorker.run_coroutine(
-            gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
+            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
         )

-    @classmethod
-    @once_differentiable
-    def backward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
-        )
-
-        return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+        return (DUMMY, None, None, None, *deserialized_grad_inputs)
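
Note: _RemoteModuleCall stays the autograd entry point, so end-to-end usage is unchanged by this refactor. A hedged sketch, assuming expert is a RemoteExpert obtained elsewhere (for example via the spawn_experts helpers above) with a single tensor input and output matching the shapes below:

    import torch

    x = torch.randn(4, 1024, requires_grad=True)
    y = expert(x)              # forward() -> RemoteExpertWorker.run_coroutine(expert_forward(...))
    loss = y.sum()
    loss.backward()            # backward() -> RemoteExpertWorker.run_coroutine(expert_backward(...))
    print(x.grad.shape)        # gradients deserialized from the remote expert's rpc_backward / rpc_backward_stream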