from __future__ import annotations

import os
from concurrent.futures import Future
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

import hivemind
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT
from hivemind.p2p import P2P, PeerInfo, StubBase
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import (
    MSGPackSerializer,
    amap_in_executor,
    iter_as_aiter,
    nested_compare,
    nested_flatten,
    nested_pack,
    switch_to_uvloop,
)
from hivemind.utils.mpfuture import MPFuture
from hivemind.utils.streaming import gather_from_streaming, split_for_streaming

DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert


def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
    return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)


@dataclass(frozen=True)
class RemoteExpertInfo:
    uid: str
    peer_info: PeerInfo


class RemoteExpert(nn.Module):
    """
    A simple module that runs forward/backward of an expert hosted on a remote machine.
    Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)

    Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
    Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to an error in the runtime.

    :param expert_info: RemoteExpertInfo with the unique expert uid and the hosting server's PeerInfo
    :param p2p: P2P instance used to communicate with the expert's server
    """

    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
        super().__init__()
        self._info, self.p2p = expert_info, p2p
        self._rpc_info = None

    @property
    def uid(self):
        return self._info.uid

    @property
    def server_peer_info(self):
        return self._info.peer_info

    @property
    def stub(self) -> StubBase:
        return _get_expert_stub(self.p2p, self.server_peer_info)

    def forward(self, *args, **kwargs):
        """Call RemoteExpert for the specified inputs and return its output(s). Compatible with pytorch.autograd."""
        assert len(kwargs) == len(self.info["keyword_names"]), f"Keyword args should be {self.info['keyword_names']}"
        kwargs = {key: kwargs[key] for key in self.info["keyword_names"]}
        # Note: we put keyword arguments in the same order as on the server to prevent f(a=1, b=2) != f(b=2, a=1) errors

        forward_inputs = (args, kwargs)

        if not nested_compare(forward_inputs, self.info["forward_schema"]):
            raise TypeError("Inputs do not match expert input schema. Did you pass the right number of parameters?")

        flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.stub, self.info, *nested_flatten(forward_inputs))
        # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad
        return nested_pack(flat_outputs, structure=self.info["outputs_schema"])

    @property
    def info(self):
        if self._rpc_info is None:
            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
        return self._rpc_info

    def extra_repr(self):
        return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
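

# Usage sketch (illustrative only, not part of this module): given a running hivemind DHT and a
# server hosting an expert, RemoteExpert behaves like a local nn.Module whose forward/backward
# are executed remotely. The uid, peer info, and tensor shape below are hypothetical placeholders.
#
#     expert_info = RemoteExpertInfo(uid="expert.0", peer_info=some_peer_info)
#     p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
#     expert = RemoteExpert(expert_info, p2p)
#     output = expert(torch.randn(1, hidden_dim))  # forward pass goes over the network
#     output.sum().backward()                      # backward is routed to the same server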


class RemoteExpertWorker:
    """Local thread for managing async tasks related to RemoteExpert"""

    _task_queue: Queue = Queue()
    _event_thread: Optional[Thread] = None
    _pid: int = -1

    @classmethod
    def _run(cls):
        loop = switch_to_uvloop()

        async def receive_tasks():
            while True:
                cor, future = cls._task_queue.get()
                try:
                    result = await cor
                except Exception as e:
                    future.set_exception(e)
                    continue
                if not future.cancelled():
                    future.set_result(result)

        loop.run_until_complete(receive_tasks())

    @classmethod
    def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
        if cls._event_thread is None or cls._pid != os.getpid():
            # start the worker thread lazily; restart it after fork, when the pid changes
            cls._pid = os.getpid()
            cls._event_thread = Thread(target=cls._run, daemon=True)
            cls._event_thread.start()

        future = Future()
        cls._task_queue.put((coro, future))

        if return_future:
            return future

        result = future.result()
        return result

    @classmethod
    def _spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], p2p: P2P) -> List[Optional[RemoteExpert]]:
        experts: List[Optional[RemoteExpert]] = []
        for i in infos:
            if i is not None:
                experts.append(RemoteExpert(i, p2p))
            else:
                experts.append(None)
        return experts

    @classmethod
    def spawn_experts(
        cls, infos: Union[Sequence[Optional[RemoteExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
    ) -> Union[List[Optional[RemoteExpert]], Future]:
        if return_future:

            async def _unpack(infos_future: MPFuture, dht: DHT):
                p2p = await dht.replicate_p2p()
                return cls._spawn_experts(await infos_future, p2p)

            return cls.run_coroutine(_unpack(infos, dht), return_future)

        p2p = cls.run_coroutine(dht.replicate_p2p())
        return cls._spawn_experts(infos, p2p)

    @classmethod
    def batch_spawn_experts(
        cls,
        infos: Union[Sequence[Sequence[Optional[RemoteExpertInfo]]], MPFuture],
        dht: DHT,
        return_future: bool = False,
    ) -> Union[List[List[Optional[RemoteExpert]]], Future]:
        if return_future:

            async def _unpack(infos_future: MPFuture, dht: DHT):
                p2p = await dht.replicate_p2p()
                return [cls._spawn_experts(i, p2p) for i in await infos_future]

            return cls.run_coroutine(_unpack(infos, dht), return_future)

        return [cls.spawn_experts(exps, dht) for exps in infos]
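

# A minimal sketch of how callers use RemoteExpertWorker (illustrative, assuming the `expert`
# object from the sketch above): it lazily starts one daemon thread with its own uvloop event
# loop and lets synchronous code submit coroutines to it, either blocking on the result or
# receiving a concurrent.futures.Future.
#
#     async def _fetch_info(stub, uid):
#         return await stub.rpc_info(runtime_pb2.ExpertUID(uid=uid))
#
#     info_response = RemoteExpertWorker.run_coroutine(_fetch_info(expert.stub, expert.uid))
#     future = RemoteExpertWorker.run_coroutine(_fetch_info(expert.stub, expert.uid), return_future=True)
#     info_response = future.result()  # non-blocking submission, blocking retrieval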


async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

    grad_inputs = await stub.rpc_backward_stream(
        amap_in_executor(
            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
            iter_as_aiter(split),
        ),
    )
    return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)


async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
    )
    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]


async def expert_backward(
    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
) -> List[torch.Tensor]:
    serialized_tensors = (
        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
    )
    size = sum(t.element_size() * t.nelement() for t in inputs_and_grads)
    if size >= DEFAULT_MAX_MSG_SIZE:
        return await _backward_stream(uid, serialized_tensors, stub)
    else:
        return await _backward_unary(uid, serialized_tensors, stub)


async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

    outputs = await stub.rpc_forward_stream(
        amap_in_executor(
            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor]),
            iter_as_aiter(split),
        ),
    )
    return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)


async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors))
    )
    return [deserialize_torch_tensor(t) for t in outputs.tensors]


async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
    serialized_tensors = (
        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
    )
    size = sum(t.element_size() * t.nelement() for t in inputs)
    if size >= DEFAULT_MAX_MSG_SIZE:
        return await _forward_stream(uid, serialized_tensors, stub)
    else:
        return await _forward_unary(uid, serialized_tensors, stub)
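

# Size-based dispatch note (a rough illustration, assuming float32 inputs): the payload size is
# estimated from the raw tensors before compression, e.g. a (512, 1024) float32 tensor contributes
# 512 * 1024 * 4 bytes = 2 MiB. Requests whose combined tensor size reaches DEFAULT_MAX_MSG_SIZE go
# through the streaming RPCs (rpc_forward_stream / rpc_backward_stream), with each serialized tensor
# split into parts of at most DEFAULT_MAX_MSG_SIZE // 2; smaller requests use the unary calls.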


class _RemoteModuleCall(torch.autograd.Function):
    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""

    @classmethod
    def forward(
        cls,
        ctx,
        dummy: torch.Tensor,
        uid: str,
        stub,  #: ConnectionHandlerStub,
        info: Dict[str, Any],
        *inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
        # detach to avoid pickling the computation graph
        inputs = tuple(tensor.cpu().detach() for tensor in inputs)
        ctx.uid, ctx.stub, ctx.info = uid, stub, info
        ctx.save_for_backward(*inputs)

        deserialized_outputs = RemoteExpertWorker.run_coroutine(
            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
        )

        return tuple(deserialized_outputs)

    @classmethod
    @once_differentiable
    def backward(cls, ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]:
        grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
        inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
        backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))

        deserialized_grad_inputs = RemoteExpertWorker.run_coroutine(
            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
        )

        return (DUMMY, None, None, None, *deserialized_grad_inputs)