@@ -1,7 +1,9 @@
+import asyncio
import contextlib
from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union

import torch
+from async_timeout import timeout
from hivemind import (
    DHT,
    MSGPackSerializer,
@@ -37,13 +39,19 @@ class TransformerConnectionHandler(ConnectionHandler):
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
+        *,
        inference_max_length: int,
+        request_timeout: float,
+        session_timeout: float,
+        step_timeout: float,
        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
    ):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)
self.inference_max_length = inference_max_length
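+        # All timeouts are in seconds: request_timeout bounds forward/backward RPCs,
+        # while session_timeout and step_timeout bound whole inference sessions and their individual steps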
+        self.request_timeout = request_timeout
+        self.session_timeout, self.step_timeout = session_timeout, step_timeout
        self._prioritizer = task_prioritizer

    async def _gather_inputs(
@@ -76,227 +84,240 @@ class TransformerConnectionHandler(ConnectionHandler):
    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""

-        request = await anext(requests)
-        requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_inference.open", requested_uids, context)
-        try:
-            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            max_length = metadata.get("max_length")
-            points = metadata.get("points", 0)
+ async with timeout(self.session_timeout):
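+            # the whole inference session is bounded by session_timeout;
+            # waiting for each new request from the client is additionally bounded by step_timeout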
+            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_inference.open", requested_uids, context)
+            try:
+                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+                max_length = metadata.get("max_length")
+                points = metadata.get("points", 0)
+
+                if not requested_uids:
+                    raise ValueError("User must specify at least one block for inference, but got none")
+                assert isinstance(
+                    max_length, int
+                ), f"rpc_inference metadata must contain int max_length, got {max_length}"
+                assert isinstance(
+                    points, (float, int)
+                ), f"rpc_inference should have number of points as a number or None, got {points}"
+                if not 0 <= max_length <= self.inference_max_length:
+                    raise ValueError(
+                        f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
+                    )

-            if not requested_uids:
-                raise ValueError("User must specify at least one block for inference, but got none")
-            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
-            assert isinstance(
-                points, (float, int)
-            ), f"rpc_inference should have number of points as a number or None, got {points}"
-            if not 0 <= max_length <= self.inference_max_length:
-                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
-
-            point_per_piece = points / max_length if max_length > 0 else 0.0
-            batch_size = request.tensors[0].size[0] if request.tensors else 1
-
-            cache_metadata = torch.tensor(
-                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
-            ) # [cache_handle, prefix_length]
-            prefix_length = 0
-
-            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
-                assert len(cache_handles) == len(requested_backends)
-                while request.tensors: # iterate while user is willing to supply tensors
-                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-
-                    # Cast inputs to backend dtype
-                    hidden_states = hidden_states.to(requested_backends[0].dtype)
-                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
-
-                    # parse deep prompts (optional argument)
-                    if prompts is None or is_dummy(prompts) or is_dummy(prompts):
-                        prompts = [DUMMY] * len(requested_backends)
-                    else:
-                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
-
-                    if not (len(requested_backends) == len(prompts)):
-                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
-
-                    length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
-                    if prefix_length + length_increment > max_length:
-                        raise ValueError(
-                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
-                            f" exceeds pre-allocated maximum {max_length}"
- )
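+            # spread the declared points evenly over the allocated sequence length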
+            point_per_piece = points / max_length if max_length > 0 else 0.0
+            batch_size = request.tensors[0].size[0] if request.tensors else 1

-            # run request tensors through all requested modules, update caches
-            for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
-                if not is_dummy(prompt):
-                    hidden_states[:, : prompt.shape[1]] += prompt
-
-                cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
-                assert isinstance(
-                    hidden_states, torch.Tensor
-                ), f"hidden states must be tensor, got {type(hidden_states)}"
-                assert (
-                    hidden_states.ndim == 3
-                ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
-                assert isinstance(
-                    backend.inference_pool, PrioritizedTaskPool
-                ), "petals support only prioritized pools"
-                priority = self._prioritizer.prioritize(
-                    cache_metadata,
-                    hidden_states,
-                    hypo_ids,
-                    points=point_per_piece / len(requested_backends),
-                    backend=backend,
-                    type="inference",
-                )
-                (hidden_states,) = await backend.inference_pool.submit_task(
-                    cache_metadata, hidden_states, hypo_ids, priority=priority
-                )
+            cache_metadata = torch.tensor(
+                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
+            ) # [cache_handle, prefix_length]
+            prefix_length = 0

-                # serialize and send last layer outputs
-                yield runtime_pb2.ExpertResponse(
-                    tensors=[
-                        serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                        for result, proto in zip(
-                            (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
-                        )
+            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
+                assert len(cache_handles) == len(requested_backends)
+                while request.tensors: # iterate while user is willing to supply tensors
+                    hidden_states, prompts, hypo_ids = [
+                        deserialize_torch_tensor(tensor) for tensor in request.tensors
                    ]
-            )

-            # prepare for next step
-            prefix_length += hidden_states.shape[1]
-            request = await (anext(requests))
-        finally:
-            self._log_request("rpc_inference.close", requested_uids, context)
+                    # Cast inputs to backend dtype
+                    hidden_states = hidden_states.to(requested_backends[0].dtype)
+                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
+
+                    # parse deep prompts (optional argument)
+                    if prompts is None or is_dummy(prompts):
+                        prompts = [DUMMY] * len(requested_backends)
+                    else:
+                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
+
+                    if not (len(requested_backends) == len(prompts)):
+                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
+
+                    length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
+                    if prefix_length + length_increment > max_length:
+                        raise ValueError(
+                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
+                            f" exceeds pre-allocated maximum {max_length}"
+                        )
+
+                    # run request tensors through all requested modules, update caches
+                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
+                        if not is_dummy(prompt):
+                            hidden_states[:, : prompt.shape[1]] += prompt
+
+                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
+                        assert isinstance(
+                            hidden_states, torch.Tensor
+                        ), f"hidden states must be tensor, got {type(hidden_states)}"
+                        assert (
+                            hidden_states.ndim == 3
+                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                        assert isinstance(
+                            backend.inference_pool, PrioritizedTaskPool
+                        ), "petals support only prioritized pools"
+                        priority = self._prioritizer.prioritize(
+                            cache_metadata,
+                            hidden_states,
+                            hypo_ids,
+                            points=point_per_piece / len(requested_backends),
+                            backend=backend,
+                            type="inference",
+                        )
+                        (hidden_states,) = await backend.inference_pool.submit_task(
+                            cache_metadata, hidden_states, hypo_ids, priority=priority
+                        )
+
+                    # serialize and send last layer outputs
+                    yield runtime_pb2.ExpertResponse(
+                        tensors=[
+                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                            for result, proto in zip(
+                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
+                            )
+                        ]
+                    )
+
+                    # prepare for next step
+ prefix_length += hidden_states.shape[1]
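+                    # wait for the next request from the client, but no longer than step_timeout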
+                    request = await asyncio.wait_for(anext(requests), self.step_timeout)
+        finally:
+            self._log_request("rpc_inference.close", requested_uids, context)

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
-        # Parse request and prepare backends
-        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_forward", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_forward should have number of points as number or None, got {points}"
-
-        hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+ async with timeout(self.request_timeout):
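+            # the entire forward pass, including (de)serialization, must finish within request_timeout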
+            # Parse request and prepare backends
+            flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_forward", requested_uids, context)

-        # Serialize output and respond to client
-        return runtime_pb2.ExpertResponse(
-            tensors=[
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-            ]
-        )
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward should have number of points as number or None, got {points}"
+
+            hidden_states = await _rpc_forward(
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+            assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+
+            # Serialize output and respond to client
+            return runtime_pb2.ExpertResponse(
+                tensors=[
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                    for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+                ]
+            )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
-        # Parse requests and prepare backends
-        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
-        requested_uids = self._check_uids(uid_str)
-        self._log_request("rpc_forward_stream", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_forward_stream should have number of points as number or None, got {points}"
-
-        hidden_states = await _rpc_forward(
-            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
-        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
+ async with timeout(self.request_timeout):
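+            # for streaming requests, gathering the input parts also counts toward request_timeout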
+            # Parse requests and prepare backends
+            uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
+            requested_uids = self._check_uids(uid_str)
+            self._log_request("rpc_forward_stream", requested_uids, context)

-        # Serialize the overall output
-        serialized_output = [
-            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
-        ]
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_forward_stream should have number of points as number or None, got {points}"

-        # Split the serialized_output for streaming and respond to client
-        output_split = [
-            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-        ]
-        async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part])
+            hidden_states = await _rpc_forward(
+                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+            assert (
+                isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
+            ), "hidden_states must be a 3d tensor"
+
+            # Serialize the overall output
+            serialized_output = [
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
+            ]
+
+            # Split the serialized_output for streaming and respond to client
+            output_split = [
+                part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]
+            async for part in as_aiter(*output_split):
+                yield runtime_pb2.ExpertResponse(tensors=[part])

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
-        # Parse requests and prepare backends
-        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
-        requested_uids = self._check_uids(request.uid)
-        self._log_request("rpc_backward", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_backward should have number of points as number or None, got {points}"
-
-        grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
+ async with timeout(self.request_timeout):
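+            # backward passes are bounded by the same request_timeout as forward passes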
+            # Parse requests and prepare backends
+            flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+            requested_uids = self._check_uids(request.uid)
+            self._log_request("rpc_backward", requested_uids, context)

-        # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_backward should have number of points as number or None, got {points}"

-        grad_inputs_schema_with_prompts = (
-            requested_backends[0].args_schema * len(grads),
-            requested_backends[0].kwargs_schema,
-        ) # TODO generalize
+            grads = await _rpc_backward(
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )

-        # Serialize the overall grad_input and respond
-        return runtime_pb2.ExpertResponse(
-            tensors=[
-                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-            ]
-        )
+            # Modify grad_inputs_schema to support grad_prompts
+            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
+
+            grad_inputs_schema_with_prompts = (
+                requested_backends[0].args_schema * len(grads),
+                requested_backends[0].kwargs_schema,
+            ) # TODO generalize
+
+            # Serialize the overall grad_input and respond
+            return runtime_pb2.ExpertResponse(
+                tensors=[
+                    serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                    for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+                ]
+            )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
-        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
-        requested_uids = self._check_uids(uids_header)
-        self._log_request("rpc_backward_stream", requested_uids, context)
-
-        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-        points = metadata.get("points", 0)
-        assert isinstance(
-            points, (float, int)
-        ), f"rpc_backward_stream should have number of points as number or None, got {points}"
-
-        grads = await _rpc_backward(
-            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
-        )
+        async with timeout(self.request_timeout):
+            uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
+            requested_uids = self._check_uids(uids_header)
+            self._log_request("rpc_backward_stream", requested_uids, context)
+
+            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            points = metadata.get("points", 0)
+            assert isinstance(
+                points, (float, int)
+            ), f"rpc_backward_stream should have number of points as number or None, got {points}"
+
+            grads = await _rpc_backward(
+                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+            )
+
+            # Modify grad_inputs_schema to support grad_prompts
+            assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
+            grad_inputs_schema_with_prompts = (
+                requested_backends[0].args_schema * len(grads),
+                requested_backends[0].kwargs_schema,
+            ) # TODO generalize
+
+            # Serialize the overall grad_inputs
+            serialized_grad_inputs = [
+                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
+                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
+            ]
+            # Split the serialized_grad_inputs for streaming and respond
+            output_split = [
+                part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+            ]

-        # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
-        grad_inputs_schema_with_prompts = (
-            requested_backends[0].args_schema * len(grads),
-            requested_backends[0].kwargs_schema,
-        ) # TODO generalize
-
-        # Serialize the overall grad_inputs
-        serialized_grad_inputs = [
-            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
-        ]
-        # Split the serialized_grad_inputs for streaming and respond
-        output_split = [
-            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
-        ]
-
-        async for part in as_aiter(*output_split):
-            yield runtime_pb2.ExpertResponse(tensors=[part])
+        async for part in as_aiter(*output_split):
+            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
        """Check that the first request to rpc_inference is valid"""