@@ -1,9 +1,16 @@
 import contextlib
-from typing import AsyncIterator, Dict, Sequence, Optional, List
+from typing import AsyncIterator, Dict, List, Optional, Sequence
 
 import torch
-from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, \
-    serialize_torch_tensor, MSGPackSerializer
+from hivemind import (
+    DHT,
+    MSGPackSerializer,
+    P2PContext,
+    TensorDescriptor,
+    deserialize_torch_tensor,
+    nested_flatten,
+    serialize_torch_tensor,
+)
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
@@ -231,7 +238,7 @@ async def _rpc_forward(*flat_tensors: torch.Tensor, requested_backends: Sequence
     for backend, prompt in zip(requested_backends, prompts):
         (hidden_states,) = await backend.forward_pool.submit_task(hidden_states)
         if not is_dummy(prompt):
-            hidden_states[:, :min(seq_length, prompt.shape[1]), ...] += prompt
+            hidden_states[:, : min(seq_length, prompt.shape[1]), ...] += prompt
         assert isinstance(hidden_states, torch.Tensor)
         assert hidden_states.ndim == 3, f"{type(backend)} must return a list with a single 3d tensor of hidden states"
 
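(For context: the slice reformatted above injects a tuned prompt into the leading positions of the hidden states, clamping the slice to the sequence length. A minimal standalone sketch of just that indexing; the shapes are illustrative assumptions, not values from the patch:)

```python
import torch

# Minimal sketch of the prompt-injection slice from the hunk above:
# add a tuned prompt in place to the first positions of a
# (batch, seq_length, hidden_dim) activation tensor.
batch, seq_length, hidden_dim = 2, 5, 8
hidden_states = torch.zeros(batch, seq_length, hidden_dim)
prompt = torch.ones(batch, 3, hidden_dim)  # prompt spans the first 3 positions

hidden_states[:, : min(seq_length, prompt.shape[1]), ...] += prompt
print(hidden_states[0, :, 0])  # tensor([1., 1., 1., 0., 0.])
```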
@@ -256,7 +263,7 @@ async def _rpc_backward(inputs: torch.Tensor, prompts: torch.Tensor, grad_output
     for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
         assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
         if not is_dummy(prompt):
-            inputs = inputs.clone() # TODO
+            inputs = inputs.clone()  # TODO
             inputs[:, :pre_seq_len] += prompt
         (inputs,) = await backend.forward_pool.submit_task(inputs)
         assert isinstance(inputs, torch.Tensor)
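(The `clone()` kept above presumably guards references to `inputs` held elsewhere in the backward pass: without it, the in-place `+=` on the next line would mutate any previously saved activation that aliases the same tensor. A minimal sketch of that aliasing hazard, with hypothetical names, not the actual `_rpc_backward` code:)

```python
import torch

# Aliasing hazard a clone() before an in-place `+=` guards against.
saved = []                          # stands in for activations kept for backward
x = torch.zeros(1, 4, 8)            # (batch, seq_len, hidden)
prompt = torch.ones(1, 2, 8)

saved.append(x)
x[:, :2] += prompt                  # in-place: silently mutates saved[0] too
assert saved[0].sum().item() == 16  # the saved activation was corrupted

saved.clear()
y = torch.zeros(1, 4, 8)
saved.append(y)
y = y.clone()                       # detach from the saved reference first
y[:, :2] += prompt
assert saved[0].sum().item() == 0   # the saved activation stays intact
```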