|
@@ -1,5 +1,5 @@
|
|
|
import contextlib
|
|
|
-from typing import AsyncIterator, Dict, List, Sequence, Union, Tuple, Iterable
|
|
|
+from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
from hivemind import (
|
|
@@ -7,19 +7,19 @@ from hivemind import (
|
|
|
MSGPackSerializer,
|
|
|
P2PContext,
|
|
|
TensorDescriptor,
|
|
|
- deserialize_torch_tensor,
|
|
|
deserialize_tensor_stream,
|
|
|
+ deserialize_torch_tensor,
|
|
|
nested_flatten,
|
|
|
serialize_torch_tensor,
|
|
|
)
|
|
|
from hivemind.moe.server.connection_handler import ConnectionHandler
|
|
|
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
-from hivemind.utils.asyncio import anext, amap_in_executor, as_aiter
|
|
|
+from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
|
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
|
|
|
|
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
|
|
-from src.server.backend import MAX_LENGTH, TransformerBackend, PrioritizedTaskPool
|
|
|
+from src.server.backend import MAX_LENGTH, PrioritizedTaskPool, TransformerBackend
|
|
|
from src.utils.misc import DUMMY, is_dummy
|
|
|
|
|
|
|
|
@@ -55,7 +55,6 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
inputs = await deserialize_tensor_stream(tensors_stream)
|
|
|
return expert_uid, inputs, metadata
|
|
|
|
|
|
-
|
|
|
async def rpc_inference(
|
|
|
self,
|
|
|
requests: AsyncIterator[runtime_pb2.ExpertRequest],
|
|
@@ -80,7 +79,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
while request.tensors: # iterate while user is willing to supply tensors
|
|
|
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- pollen = metadata.get("pollen", 0.0)
|
|
|
+ dust = metadata.get("__dust", 0.0)
|
|
|
|
|
|
# Cast inputs to backend dtype
|
|
|
hidden_states = [tensor.to(requested_backends[0].dtype) for tensor in hidden_states]
|
|
@@ -92,7 +91,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
len(hidden_states) == 1 and hidden_states[0].ndim == 3
|
|
|
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
|
|
|
if isinstance(backend.inference_pool, PrioritizedTaskPool):
|
|
|
- hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states, pollen)
|
|
|
+ hidden_states = await backend.inference_pool.submit_task(
|
|
|
+ cache_metadata, *hidden_states, dust
|
|
|
+ )
|
|
|
else:
|
|
|
hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
|
|
|
assert isinstance(hidden_states, (list, tuple))
|
|
@@ -120,9 +121,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_uids = self._check_uids(request.uid)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- pollen = metadata.get("pollen", 0.0)
|
|
|
+ dust = metadata.get("__dust", 0.0)
|
|
|
|
|
|
- hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, pollen=pollen)
|
|
|
+ hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, dust=dust)
|
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
|
|
|
|
|
|
# Serialize output and respond to client
|
|
@@ -141,7 +142,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_uids = self._check_uids(uid_str)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
|
|
|
- hidden_states = await _rpc_forward(*flat_inputs, requested_backends=requested_backends, pollen=metadata.get("pollen", 0.0))
|
|
|
+ hidden_states = await _rpc_forward(
|
|
|
+ *flat_inputs, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
|
|
|
+ )
|
|
|
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
|
|
|
|
|
|
# Serialize the overall output
|
|
@@ -163,9 +166,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_uids = self._check_uids(request.uid)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
- pollen = metadata.get("pollen", 0.0)
|
|
|
+ dust = metadata.get("__dust", 0.0)
|
|
|
|
|
|
- grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, pollen=pollen)
|
|
|
+ grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, dust=dust)
|
|
|
|
|
|
# Modify grad_inputs_schema to support grad_prompts
|
|
|
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
|
|
@@ -191,7 +194,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
requested_uids = self._check_uids(uids_header)
|
|
|
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
|
|
|
|
|
|
- grads = await _rpc_backward(*flat_tensors, requested_backends=requested_backends, pollen=metadata.get("pollen", 0.0))
|
|
|
+ grads = await _rpc_backward(
|
|
|
+ *flat_tensors, requested_backends=requested_backends, dust=metadata.get("__dust", 0.0)
|
|
|
+ )
|
|
|
|
|
|
# Modify grad_inputs_schema to support grad_prompts
|
|
|
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
|
|
@@ -242,7 +247,9 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
yield handles
|
|
|
|
|
|
|
|
|
-async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], pollen: float = 0.0) -> torch.Tensor:
|
|
|
+async def _rpc_forward(
|
|
|
+ *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
|
|
|
+) -> torch.Tensor:
|
|
|
"""
|
|
|
Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
|
|
|
|
|
@@ -269,7 +276,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
|
|
|
if not is_dummy(prompt):
|
|
|
hidden_states[:, :pre_seq_len] += prompt
|
|
|
if isinstance(backend.forward_pool, PrioritizedTaskPool):
|
|
|
- (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, pollen)
|
|
|
+ (hidden_states,) = await backend.forward_pool.submit_task(hidden_states, dust)
|
|
|
else:
|
|
|
(hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
|
|
|
assert isinstance(hidden_states, torch.Tensor)
|
|
@@ -282,7 +289,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
|
|
|
|
|
|
|
|
|
async def _rpc_backward(
|
|
|
- *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], pollen: float = 0.0
|
|
|
+ *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], dust: float = 0.0
|
|
|
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
|
|
|
inputs, grad_outputs, *prompts = flat_tensors
|
|
|
# Cast inputs & grad outputs to backend dtype
|
|
@@ -307,7 +314,7 @@ async def _rpc_backward(
|
|
|
inter_inputs.append(inputs)
|
|
|
|
|
|
if isinstance(backend.forward_pool, PrioritizedTaskPool):
|
|
|
- (inputs,) = await backend.forward_pool.submit_task(inputs, pollen / 2.0)
|
|
|
+ (inputs,) = await backend.forward_pool.submit_task(inputs, dust / 2.0)
|
|
|
else:
|
|
|
(inputs,) = await backend.forward_pool.submit_task(inputs)
|
|
|
|
|
@@ -322,7 +329,7 @@ async def _rpc_backward(
|
|
|
# Run a chain of requested backends
|
|
|
for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
|
|
|
if isinstance(backend.backward_pool, PrioritizedTaskPool):
|
|
|
- (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, pollen / 2.0)
|
|
|
+ (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, dust / 2.0)
|
|
|
else:
|
|
|
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs)
|
|
|
assert isinstance(grad_outputs, torch.Tensor)
|