Workable version

artek0chumak, 2 years ago
commit 3a66dd19d4

+ 8 - 2
src/petals/client/inference_session.py

@@ -78,7 +78,7 @@ class _ServerInferenceSession:
     def step(
         self,
         new_hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -103,6 +103,9 @@ class _ServerInferenceSession:
         else:
             assert len(hypo_ids) == len(new_hidden_states)
             assert hypo_ids.dtype == torch.int64
+            
+        if attention_mask is None:
+            attention_mask = DUMMY
 
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states, attention_mask, prompts, hypo_ids)
@@ -214,7 +217,7 @@ class InferenceSession:
         return self
 
     def step(
-        self, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs
+        self, inputs: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, prompts: Optional[torch.Tensor] = None, **kwargs
     ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
@@ -225,6 +228,9 @@ class InferenceSession:
             prompts = DUMMY
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+            
+        if attention_mask is None:
+            attention_mask = DUMMY
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
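
With this change, both _ServerInferenceSession.step and InferenceSession.step accept attention_mask=None and substitute the DUMMY sentinel before serializing the request, so callers that never built a mask keep working. A minimal usage sketch, assuming an already-open session `sess` and a `config` object (neither is defined in this commit):

    import torch

    # Hedged sketch: `sess` is an open petals InferenceSession, `config` a model config.
    hidden = torch.randn(1, 8, config.hidden_size)

    out = sess.step(hidden)                                   # mask omitted -> DUMMY is sent
    out = sess.step(hidden, attention_mask=torch.ones(1, 8))  # explicit mask still supported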

+ 2 - 1
src/petals/client/remote_forward_backward.py

@@ -118,6 +118,7 @@ async def run_remote_backward(
     stub: StubBase,
     rpc_info: RPCInfo,
     inputs: torch.Tensor,
+    attention_masks: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
     timeout: float,
@@ -131,7 +132,7 @@ async def run_remote_backward(
     """
 
     grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
-    inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
+    inputs_and_grad_outputs = tuple(nested_flatten((inputs, attention_masks, grad_outputs_cpu, *extra_tensors)))
 
     # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]

+ 1 - 2
src/petals/client/remote_generation.py

@@ -179,9 +179,8 @@ class RemoteGenerationMixin:
                     hidden_state = torch.cat([prompts, hidden_state], dim=1)
                 hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
 
-                attention_mask = torch.ones((batch_size, seq_idx), device=hidden_state.device)
                 hidden_state = session.step(
-                    hidden_state, attention_mask, prompts=intermediate_prompts, hypo_ids=hypo_ids
+                    hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids
                 )[:, -1]
 
                 hidden_state = self.transformer.ln_f(hidden_state)

+ 1 - 1
src/petals/client/routing/sequence_manager.py

@@ -333,7 +333,7 @@ class _SequenceManagerUpdateThread(threading.Thread):
 
 def maybe_log_traceback(exc: Exception):
     traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
-    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+    logger.log(logging.INFO, "See detailed traceback below:", exc_info=True)
 
 
 class MissingBlocksError(RuntimeError):

+ 14 - 13
src/petals/client/sequential_autograd.py

@@ -25,7 +25,7 @@ MAX_TOKENS_IN_BATCH = 1024
 
 async def sequential_forward(
     inputs: torch.Tensor,
-    attention_mask: torch.Tensor,
+    attention_masks: torch.Tensor,
     prompts: torch.Tensor,
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
@@ -38,12 +38,12 @@ async def sequential_forward(
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
-    assert isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 2, f"{type(attention_mask)}: {attention_mask.ndim}"
+    assert isinstance(attention_masks, torch.Tensor) and attention_masks.ndim == 2, f"{type(attention_masks)}: {attention_masks.ndim}"
 
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
     inputs = inputs.cpu()
-    attention_mask = attention_mask.cpu()
+    attention_masks = attention_masks.cpu()
     prompts = prompts.cpu()
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
@@ -71,7 +71,7 @@ async def sequential_forward(
                 span = sequences.popleft()
 
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                inputs_and_prompts = [inputs, attention_mask, prompts[span.start : span.end]]
+                inputs_and_prompts = [inputs, attention_masks, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
@@ -114,7 +114,7 @@ async def sequential_forward(
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: List[torch.Tensor],
-    attention_mask: torch.Tensor,
+    attention_masks: torch.Tensor,
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
@@ -132,7 +132,7 @@ async def sequential_backward(
 
     grad_outputs = [tensor.cpu() for tensor in grad_outputs]
     intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
-    attention_mask = attention_mask.cpu()
+    attention_masks = attention_masks.cpu()
     prompts = prompts.cpu()
 
     grad_prompts_reversed = []
@@ -144,7 +144,7 @@ async def sequential_backward(
             try:
                 if attempt_no >= 1:
                     _, backup_inputs, backup_sequences = await sequential_forward(
-                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                        inputs, attention_masks, prompts, sequence_manager, start_index=span.start, end_index=span.end
                     )
                     assert len(backup_inputs) == len(backup_sequences)
                     assert backup_sequences[0].start == span.start
@@ -158,14 +158,14 @@ async def sequential_backward(
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
-                    "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
+                    "rpc_backward", span_uids, *inputs, attention_masks, *grad_outputs, peer_id=span.peer_id
                 )
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
                     stub,
                     sequence_manager.rpc_info,
                     inputs,
-                    attention_mask,
+                    attention_masks,
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.request_timeout,
@@ -255,27 +255,28 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_outputs: torch.Tensor):
         intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
+        attention_mask_batches: List[Sequence[torch.Tensor]] = ctx.attention_mask_batches
         forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
         ctx.sequence_manager.rpc_info  # lazy init
 
         batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
         grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
-        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)
+        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences) == len(attention_mask_batches)
 
         outputs = RemoteExpertWorker.run_coroutine(
             _gather_backward(
                 grad_output_batches,
                 intermediate_input_batches,
-                ctx.attention_mask_batches,
+                attention_mask_batches,
                 ctx.prompt_batches,
                 forward_sequences,
                 ctx.sequence_manager,
             )
         )
         grad_input_batches = [output[0][0] for output in outputs]
-        grad_prompt_batches = [output[2] for output in outputs]
+        grad_prompt_batches = [output[1] for output in outputs]
 
         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
         grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None
-        return (grad_inputs, grad_prompts, None)
+        return (grad_inputs, None, grad_prompts, None)
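
Since _RemoteSequentialAutogradFunction.forward now receives attention_masks as an extra positional input, backward must return one gradient slot per forward argument; the mask is not differentiated, so its slot is None (hence the four-element return). A toy torch.autograd.Function showing the same convention, written only for illustration and not part of this commit:

    import torch

    class MaskedIdentity(torch.autograd.Function):
        """Toy (inputs, mask) -> outputs function; mirrors the signature pattern only."""

        @staticmethod
        def forward(ctx, inputs, mask):
            ctx.save_for_backward(mask)
            return inputs * mask

        @staticmethod
        def backward(ctx, grad_output):
            (mask,) = ctx.saved_tensors
            # One return value per forward input: a gradient for `inputs`,
            # None for the non-differentiable `mask`.
            return grad_output * mask, None

    x = torch.randn(2, 3, requires_grad=True)
    MaskedIdentity.apply(x, torch.ones(2, 3)).sum().backward()  # fills x.grad; the mask gets no grad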

+ 11 - 7
src/petals/server/handler.py

@@ -134,7 +134,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
                     assert len(cache_handles) == len(requested_backends)
                     while request.tensors:  # iterate while user is willing to supply tensors
-                        hidden_states, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
+                        hidden_states, attention_mask, prompts, hypo_ids = map(deserialize_torch_tensor, request.tensors)
 
                         # Cast inputs to backend dtype
                         hidden_states = hidden_states.to(requested_backends[0].dtype)
@@ -156,6 +156,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                                 f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                                 f" exceeds pre-allocated maximum {max_length}"
                             )
+                            
+                        if is_dummy(attention_mask):
+                            attention_mask = torch.ones((hidden_states.shape[0], prefix_length + length_increment), dtype=hypo_ids.dtype)
 
                         priority = self._prioritizer.prioritize(
                             hidden_states,
@@ -317,9 +320,9 @@ class TransformerConnectionHandler(ConnectionHandler):
     ) -> Sequence[runtime_pb2.Tensor]:
         """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema"""
         # Modify grad_inputs_schema to support grad_prompts
-        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
+        assert len(requested_backends[0].args_schema) == 2 and len(grads) in (1, 2)  # TODO generalize
         flat_grads_schema = tuple(
-            nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema))
+            nested_flatten((requested_backends[0].args_schema[:1] * len(grads), requested_backends[0].kwargs_schema))
         )  # TODO generalize
 
         if metadata.get("output_compression") is not None:
@@ -435,7 +438,7 @@ async def _rpc_forward(
             attention_masks,
             priority=priority,
         )
-        assert isinstance(hidden_states, torch.Tensor)
+        assert isinstance(hidden_states, torch.Tensor), f"hidden_states is {hidden_states}"
         assert (
             hidden_states.ndim == 3
         ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
@@ -474,7 +477,7 @@ async def _rpc_backward(
         )
         (inputs,) = await backend.forward_pool.submit_task(inputs, attention_masks, priority=priority)
 
-        assert isinstance(inputs, torch.Tensor)
+        assert isinstance(inputs, torch.Tensor), f"inputs is {inputs}"
 
     if not is_dummy(prompts[-1]):
         inputs[:, : prompts[-1].shape[1]] += prompts[-1]
@@ -488,9 +491,10 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, attention_masks, grad_outputs, priority=priority)
+        grad_outputs = await backend.backward_pool.submit_task(inp, attention_masks, grad_outputs, priority=priority)
+        grad_outputs = grad_outputs[0]
 
-        assert isinstance(grad_outputs, torch.Tensor)
+        assert isinstance(grad_outputs, torch.Tensor), f"grad_outputs is {grad_outputs}"
         if not is_dummy(prompt):
             grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))
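
On the server side, rpc_inference now deserializes four tensors per step (hidden_states, attention_mask, prompts, hypo_ids) and expands a DUMMY attention mask into an all-ones mask covering the cached prefix plus the newly supplied tokens. The fallback in isolation, as a hedged sketch (DUMMY and is_dummy are assumed to mirror petals.utils.misc; shapes and dtype are illustrative):

    import torch

    DUMMY = torch.empty(0)          # sentinel for "no tensor provided"

    def is_dummy(tensor: torch.Tensor) -> bool:
        return tensor.numel() == 0

    def resolve_attention_mask(attention_mask, hidden_states, prefix_length, length_increment):
        # Client sent no mask: attend to every position, one row per sequence,
        # one column per cached token plus each new token in this step.
        if is_dummy(attention_mask):
            attention_mask = torch.ones(
                (hidden_states.shape[0], prefix_length + length_increment), dtype=torch.int64
            )
        return attention_mask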
 

+ 2 - 1
src/petals/server/throughput.py

@@ -155,9 +155,10 @@ def measure_compute_rps(
         elapsed = 0
         for step in range(n_steps + 1):
             dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+            dummy_mask = torch.ones((n_tokens, 1), device=device, dtype=dtype)
 
             start_time = time.perf_counter()
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
+            _, cache = block.forward(dummy_input, dummy_mask, use_cache=True, layer_past=cache)
             if step >= 1:  # Skip the 1st step to exclude the initialization time
                 elapsed += time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed

+ 3 - 2
tests/test_block_exact_match.py

@@ -22,7 +22,8 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
         assert isinstance(remote_block, RemoteTransformerBlock)
 
         inputs = torch.randn(1, 8, config.hidden_size)
-        outputs_forward = remote_block(inputs)
+        attention_mask = torch.ones(1, 8)
+        outputs_forward = remote_block(inputs, attention_mask)
 
         outputs_inference = []
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
@@ -37,7 +38,7 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
-        (outputs_local,) = ref_block(inputs)
+        (outputs_local,) = ref_block(inputs, attention_mask)
 
         assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
         assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

+ 5 - 3
tests/test_chained_calls.py

@@ -28,14 +28,15 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
         load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
-    outputs_rpc = remote_blocks.forward(inputs)
+    attention_mask = torch.ones((1, seq_length))
+    outputs_rpc = remote_blocks.forward(inputs, attention_mask)
     outputs_rpc.sum().backward()
     grads_rpc = inputs.grad
 
     inputs.grad = None
     hidden_states = inputs
     for ref_block in ref_blocks:
-        hidden_states = ref_block.forward(hidden_states)[0]
+        hidden_states = ref_block.forward(hidden_states, attention_mask)[0]
     outputs_ref = hidden_states
     outputs_ref.sum().backward()
     grads_ref = inputs.grad
@@ -52,6 +53,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     assert isinstance(remote_blocks, RemoteSequential)
 
     inputs = torch.randn(1, 8, config.hidden_size)
+    attention_masks = torch.ones((1, 8))
 
     outputs_inference = []
     with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
@@ -70,7 +72,7 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
         hidden_states = inputs[:, i : i + 1, :]
         for ref_block, cache in zip(ref_blocks, caches):
             with torch.no_grad():
-                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
+                hidden_states, new_cache = ref_block.forward(hidden_states, attention_masks[:, :i+1], use_cache=True, layer_past=cache)
                 new_caches.append(new_cache)
 
         outputs_ref.append(hidden_states)

+ 8 - 6
tests/test_remote_sequential.py

@@ -18,11 +18,12 @@ def test_remote_sequential():
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
+    test_attention_mask = torch.ones((1, 5))
     grad_proj = torch.randn(1, 5, config.hidden_size)
 
     sequential = RemoteSequential(config, dht)
 
-    full_outputs = sequential(test_inputs)
+    full_outputs = sequential(test_inputs, test_attention_mask)
     (full_outputs * grad_proj).sum().backward()
     assert test_inputs.grad is not None
     full_grad = test_inputs.grad.clone()
@@ -35,11 +36,11 @@ def test_remote_sequential():
     for m in sequential, first_half, second_half:
         assert isinstance(repr(m), str)
 
-    hidden = first_half(test_inputs)
+    hidden = first_half(test_inputs, test_attention_mask)
     assert isinstance(hidden, torch.Tensor)
     assert hidden.shape == test_inputs.shape
     assert hidden.requires_grad
-    second_half_outputs = second_half(hidden)
+    second_half_outputs = second_half(hidden, test_attention_mask)
     assert torch.allclose(second_half_outputs, full_outputs, atol=1e-4)
 
     (second_half_outputs * grad_proj).sum().backward()
@@ -52,7 +53,7 @@ def test_remote_sequential():
     )
 
     test_inputs.grad = None
-    approx_outputs = lossy_sequential(test_inputs)
+    approx_outputs = lossy_sequential(test_inputs, test_attention_mask)
     (approx_outputs * grad_proj).sum().backward()
 
     assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
@@ -89,6 +90,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     remote_sequential = RemoteSequential(config, dht)
 
     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
+    attention_mask = torch.ones((batch_size, seq_len + pre_seq_len))
     output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
     input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
     intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
@@ -99,7 +101,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
     assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
 
-    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)
+    outputs = remote_sequential(inputs_with_prompts, attention_mask, prompts=intermediate_prompts)
 
     (outputs * output_proj).sum().backward()
     assert intermediate_prompts.grad is not None
@@ -116,7 +118,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
         outputs_ref[:, : block_prompt.shape[1]] += block_prompt
 
         block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
-        (outputs_ref,) = block(outputs_ref)
+        (outputs_ref,) = block(outputs_ref, attention_mask)
 
     assert torch.allclose(outputs_ref, outputs, atol=1e-3)
 

+ 1 - 1
tests/test_sequence_manager.py

@@ -35,7 +35,7 @@ def test_sequence_manager_basics(mode: str):
     assert sequential.sequence_manager.is_alive()
     assert sequential.sequence_manager._thread.ready.is_set()
     assert not shutdown_evt.is_set()
-    sequential(torch.randn(1, 2, config.hidden_size))
+    sequential(torch.randn(1, 2, config.hidden_size), torch.ones((1, 2)))
 
     sequential.sequence_manager.shutdown()
     del sequential

+ 8 - 4
tests/test_tensor_parallel.py

@@ -31,13 +31,17 @@ def test_tp_block(devices, custom_config):
     test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
     grad_proj = torch.rand_like(test_inputs1)
 
-    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
-    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
+    test_attention_mask1 = torch.ones((batch_size, prefix_length), device=devices[0])
+    y_prefix_ref, layer_past = block(test_prefix1, test_attention_mask1, use_cache=True)
+    test_attention_mask1 = torch.ones((batch_size, prefix_length + 3), device=devices[0])
+    y_ref, cache_ref = block(test_inputs1, test_attention_mask1, use_cache=True, layer_past=layer_past)
     y_ref.backward(grad_proj)
 
     block_tp = TensorParallel(block, devices, config=tp_config)
-    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
-    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
+    test_attention_mask2 = torch.ones((batch_size, prefix_length), device=devices[0])
+    y_prefix, layer_past = block_tp(test_prefix2, test_attention_mask2, use_cache=True)
+    test_attention_mask2 = torch.ones((batch_size, prefix_length + 3), device=devices[0])
+    y_ours, cache_ours = block_tp(test_inputs2, test_attention_mask2, use_cache=True, layer_past=layer_past)
     y_ours.backward(grad_proj)
 
     assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-5)