@@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
+                active_adapter = metadata.get("active_adapter", "")
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
 
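For context: the handler reads `active_adapter` from the same MSGPack metadata dict as `max_length` and `points`, defaulting to `""` (base model, no adapter) when the client doesn't send it. A client-side sketch of what a matching request might carry (hypothetical values; the real client changes live elsewhere in this PR):

```python
from hivemind import MSGPackSerializer

# Hypothetical client-side sketch: metadata travels as MSGPack bytes in
# request.metadata; an absent or empty "active_adapter" selects the base model.
request_metadata = MSGPackSerializer.dumps(
    {
        "max_length": 2048,                        # example value
        "active_adapter": "user/my-lora-adapter",  # hypothetical adapter name
        "points": 0,
    }
)
```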
@@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         )
 
                     inference_infos = tuple(
-                        InferenceMetadata(uid, prefix_length, tuple(handles))
+                        InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                         for uid, handles in zip(requested_uids, cache_handles)
                     )
 
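`InferenceMetadata` is now constructed with a fourth positional field, so the struct itself must grow accordingly (that change is outside this file). A minimal sketch of the assumed shape, with the field types guessed for illustration:

```python
from typing import NamedTuple, Tuple

# Assumed shape of the extended struct (types are guesses, not the real ones);
# defaulting active_adapter to "" keeps any older call sites working.
class InferenceMetadata(NamedTuple):
    uid: str
    prefix_length: int
    cache_handles: Tuple[int, ...]
    active_adapter: str = ""
```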
@@ -354,13 +355,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
             return runtime_pb2.ExpertResponse(
                 tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
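The same two-line pattern (read `active_adapter` from metadata, thread it through as a keyword argument) repeats verbatim in `rpc_forward_stream`, `rpc_backward`, and `rpc_backward_stream` below. In every handler a missing key degrades gracefully:

```python
# Behavior of the defaults when a client sends no adapter-related metadata:
metadata = {}
metadata.get("active_adapter", "")  # -> "" (run the base model, no adapter)
metadata.get("points", 0)           # -> 0
```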
@@ -376,13 +382,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_forward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_forward_stream should have number of points as number or None, got {points}"
 
             hidden_states = await _rpc_forward(
-                *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_inputs,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
 
             # Split the serialized_output for streaming and respond to client
@@ -422,13 +433,18 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
 
             return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -442,13 +458,18 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_backward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+            active_adapter = metadata.get("active_adapter", "")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
             ), f"rpc_backward_stream should have number of points as number or None, got {points}"
 
             grads = await _rpc_backward(
-                *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+                *flat_tensors,
+                requested_backends=requested_backends,
+                prioritizer=self._prioritizer,
+                active_adapter=active_adapter,
+                points=points,
             )
             # Split the serialized_grad_inputs for streaming and respond
             for tensor in self._serialize_grads(grads, requested_backends, metadata):
@@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 async def _rpc_forward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> torch.Tensor:
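Note on the parameter order: everything after `*flat_tensors` is keyword-only, so the defaulted `active_adapter` may legally precede the default-less `prioritizer`. A two-line demonstration of the rule:

```python
# Keyword-only parameters (after *args) may mix defaulted and required freely:
def f(*args, a: str = "", b, c: int = 0):
    return args, a, b, c

print(f(1, 2, b="required"))  # ((1, 2), '', 'required', 0)
```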
@@ -585,6 +607,7 @@ async def _rpc_forward(
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
             hidden_states,
+            active_adapter,
             priority=priority,
         )
         assert isinstance(hidden_states, torch.Tensor)
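Each `submit_task` call site now appends `active_adapter` as a trailing positional argument, which implies the backend consuming the pool peels it off before running the tensors through the module. A hypothetical consumer-side sketch (the real `TransformerBackend` changes are outside this diff):

```python
from typing import Tuple, Union
import torch
import torch.nn as nn

# Hypothetical sketch: the task pool forwards positional args unchanged, so the
# backend separates the trailing adapter name from the tensor payload.
class AdapterAwareBackend:
    def __init__(self, module: nn.Module):
        self.module = module

    def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
        *tensors, active_adapter = inputs  # trailing arg added by this PR's call sites
        if active_adapter:
            ...  # e.g., activate the matching LoRA weights; "" = plain base model
        return (self.module(*tensors),)
```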
@@ -598,6 +621,7 @@ async def _rpc_forward(
 async def _rpc_backward(
     *flat_tensors: torch.Tensor,
     requested_backends: Sequence[TransformerBackend],
+    active_adapter: str = "",
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
@@ -623,7 +647,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
@@ -639,7 +663,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
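A design note on why the adapter name rides along positionally rather than as `active_adapter=...`: hivemind-style task pools treat positional arguments as task payload and keyword arguments (here only `priority`) as scheduling hints, so the adapter travels with the tensors and the pool itself needs no changes. An illustrative toy pool (not hivemind's API):

```python
from typing import Any, Callable, Tuple

# Illustrative only: positional args are payload, keyword args steer scheduling.
class ToyTaskPool:
    def __init__(self, fn: Callable[..., Tuple[Any, ...]]):
        self.fn = fn

    async def submit_task(self, *payload: Any, priority: float = 0.0) -> Tuple[Any, ...]:
        # a real pool would enqueue by priority; this sketch calls straight through
        return self.fn(*payload)
```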