@@ -22,6 +22,7 @@ from hivemind.proto import runtime_pb2
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
+from src.utils.misc import DUMMY, is_dummy
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
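
For context on the new import: the diff only adds the import line, so the definitions below are an assumption about what `src/utils/misc.py` provides, not part of the patch. The idea is a zero-element sentinel tensor plus a predicate for it, so "no prompts" can be sent over the wire as an ordinary tensor.

```python
# Assumed (not shown in this diff) contents of src/utils/misc.py:
import torch

DUMMY = torch.empty(0)  # zero-element sentinel standing in for "no prompts"


def is_dummy(tensor: torch.Tensor) -> bool:
    # any tensor with zero elements is treated as the sentinel
    return tensor.numel() == 0
```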
@@ -44,6 +45,7 @@ class RemoteTransformerBlockInferenceSession:
         max_length: int,
     ):
         self.uid, self.rpc_info = uid, rpc_info
+        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
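
`self.num_blocks` is pure string arithmetic on the UID: a session that serves a chain of blocks gets a UID made of per-block UIDs joined by `CHAIN_DELIMITER`, so the number of delimiters plus one is the chain length. A tiny illustration with made-up block UIDs:

```python
from src.data_structures import CHAIN_DELIMITER

blocks = ["blockA", "blockB", "blockC"]  # hypothetical per-block UIDs
chained_uid = CHAIN_DELIMITER.join(blocks)

# same arithmetic as self.num_blocks in the constructor above
assert chained_uid.count(CHAIN_DELIMITER) + 1 == len(blocks)
```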
@@ -69,12 +71,30 @@ class RemoteTransformerBlockInferenceSession:
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
 
-    def step(self, new_hidden_states: torch.Tensor, prompts: Optional[torch.Tensor] = None):
-        """Inference step: send a chunk of input tesors and receive a chunk of outputs"""
+    def step(self,
+             new_hidden_states: torch.Tensor,
+             prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None):
+        """
+        Inference step: send a chunk of input tensors and receive a chunk of outputs
+        :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
+          if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
+        """
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4, "deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]"
+            assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
+            assert prompts.shape[2] <= new_hidden_states.shape[1]
+            assert prompts.shape[3] == new_hidden_states.shape[2]
+
+        assert hypo_ids is None, "TODO implement hypo_ids here"
+        hypo_ids = torch.arange(len(new_hidden_states))
+
         # serialize inputs and put them into the queue
-        inputs = (new_hidden_states, prompts, torch.arange(len(new_hidden_states)))
+        inputs = (new_hidden_states, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
             self._step(
                 runtime_pb2.ExpertRequest(
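
A usage sketch for the block-level session (dimensions and the `session` object are illustrative, not taken from the patch). The asserts above require one prompt row per block served by this session, a batch dimension equal to the input batch size or 1, a prefix no longer than the current chunk, and a matching hidden size:

```python
import torch

batch_size, chunk_len, hid_size, prefix_len = 2, 16, 4096, 4  # made-up dimensions

# `session` is assumed to be an open RemoteTransformerBlockInferenceSession
deep_prompts = torch.zeros(session.num_blocks, batch_size, prefix_len, hid_size)
hidden = torch.randn(batch_size, chunk_len, hid_size)

outputs = session.step(hidden, prompts=deep_prompts)  # prompts are added to the first prefix_len positions
later = session.step(torch.randn(batch_size, 1, hid_size))  # omitting prompts sends the DUMMY sentinel
```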
@@ -161,12 +181,16 @@ class RemoteSequentialInferenceSession:
 
         return self
 
-    def step(self, inputs: torch.Tensor):
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
         assert not self.closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        if prompts is None or is_dummy(prompts):
+            prompts = DUMMY
+        else:
+            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
         for session in self.inference_sessions:
-            outputs = session.step(inputs)
+            outputs = session.step(inputs, prompts[self.chosen_spans[0].start: self.chosen_spans[0].end], **kwargs)
             assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
             inputs = outputs
         return inputs
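
The corresponding sketch at the sequence level (again with made-up names and sizes): here `prompts` covers every remote block in the full sequence, and each per-server session receives only the slice for the span of blocks it hosts. Deep prompts only need to accompany the chunk that contains the prefix positions; later chunks can be stepped without them.

```python
import torch

# `sess` is assumed to be an open RemoteSequentialInferenceSession
num_layers = len(sess.sequence_manager)        # one prompt row per block in the full sequence
batch_size, prefix_len, hid_size = 1, 8, 4096  # made-up dimensions

deep_prompts = torch.zeros(num_layers, batch_size, prefix_len, hid_size)

prefix_hidden = torch.randn(batch_size, prefix_len, hid_size)
prefix_hidden = sess.step(prefix_hidden, prompts=deep_prompts)  # prompts added to this prefix chunk

next_chunk = torch.randn(batch_size, 1, hid_size)               # e.g. one new token at a time
next_chunk = sess.step(next_chunk)                              # no prompts: DUMMY is sent instead
```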