artek0chumak, 2 years ago
commit d2d0403a11

+ 3 - 4
src/petals/bloom/block.py

@@ -18,20 +18,19 @@ class WrappedBloomBlock(BloomBlock):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor,
         *args,
         alibi: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs
     ):
-        assert attention_mask is None
         batch_size, seq_length = hidden_states.shape[:2]
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        causal_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
         return super().forward(
-            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
+            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
 
     def _prepare_attn_mask(

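Note: the wrapped block now requires an explicit attention_mask and derives both the ALiBi bias and the causal mask from it, instead of asserting that no mask was passed. For intuition, here is a minimal sketch of how a padding mask combines with a causal pattern; make_causal_mask is a hypothetical helper written for illustration, not the transformers/Petals _prepare_attn_mask:

    import torch

    def make_causal_mask(attention_mask: torch.Tensor, past_length: int = 0) -> torch.Tensor:
        # attention_mask: [batch, past_length + seq_length], 1 = real token, 0 = padding
        batch_size, src_length = attention_mask.shape
        seq_length = src_length - past_length
        # query position i may attend to key positions j <= past_length + i
        causal = torch.ones(seq_length, src_length).tril(diagonal=past_length).bool()
        # True = attend, False = masked out; padded keys are masked for every query
        return causal[None, :, :] & attention_mask[:, None, :].bool()

    mask = make_causal_mask(torch.tensor([[1, 1, 1, 0]]))  # shape [1, 4, 4]
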
+ 3 - 3
src/petals/client/remote_forward_backward.py

@@ -88,8 +88,8 @@ async def run_remote_forward(
     # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
     # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+    assert len(args_schema) == 2 and len(inputs) == 3
+    forward_schema_with_prompts = ((args_schema[0], args_schema[1], args_schema[0]), kwargs_schema)
 
     if not nested_compare(forward_inputs, forward_schema_with_prompts):
         raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
@@ -135,7 +135,7 @@ async def run_remote_backward(
 
     # Modify forward_schema to support prompts
     args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    assert len(args_schema) == 2 and isinstance(inputs, torch.Tensor)
     # TODO generalize this
     prompts_schema = next(iter(args_schema))
     backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))

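Note: the client now expects two positional tensors per block (hidden states and attention mask) plus the prompts, and the prompts are validated against the same descriptor as the hidden states. A rough sketch of the expanded schema, with plain strings standing in for the hivemind tensor descriptors:

    # stand-in values for illustration only; the real rpc_info carries tensor descriptors
    args_schema = ("hidden_states_descr", "attention_mask_descr")
    kwargs_schema = {}
    inputs = ("hidden_states", "attention_mask", "prompts")

    assert len(args_schema) == 2 and len(inputs) == 3
    forward_schema_with_prompts = ((args_schema[0], args_schema[1], args_schema[0]), kwargs_schema)
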
+ 4 - 4
src/petals/client/remote_model.py

@@ -184,16 +184,16 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
+        batch_size = inputs_embeds.shape[0]
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, input_shape[-1]), device=hidden_states.device)
-
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
+        
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
 
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
             hidden_states = self.h(hidden_states, attention_mask, prompts=intermediate_prompts)

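Note: the default all-ones mask is now built after the p-tuning prompts are concatenated, so its length matches the prompt-extended sequence rather than the raw input. A small sketch with hypothetical sizes:

    import torch

    batch_size, seq_len, hidden_size, num_prompt_tokens = 2, 16, 64, 4
    inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)
    prompts = torch.randn(batch_size, num_prompt_tokens, hidden_size)

    hidden_states = torch.cat([prompts, inputs_embeds], dim=1)      # [2, 20, 64]
    attention_mask = torch.ones(batch_size, hidden_states.size(1))  # [2, 20], not [2, 16]
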
+ 2 - 2
src/petals/client/remote_sequential.py

@@ -51,10 +51,10 @@ class RemoteSequential(nn.Module):
             assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
             self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
 
-    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor = DUMMY):
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
         assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, attention_mask, prompts, self.sequence_manager)
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:

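Note: RemoteSequential.forward now takes the attention mask as a required second positional argument. The shape contract looks like this (the call itself is commented out because a RemoteSequential instance needs a live swarm; remote_blocks is a hypothetical name):

    import torch

    inputs = torch.randn(1, 128, 14336)  # [batch_size, seq_length, hidden_size], seq_length <= 2048
    attention_mask = torch.ones(1, 128)  # [batch_size, seq_length]
    # outputs = remote_blocks(inputs, attention_mask)  # prompts still default to DUMMY
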
+ 20 - 11
src/petals/client/sequential_autograd.py

@@ -25,6 +25,7 @@ MAX_TOKENS_IN_BATCH = 1024
 
 async def sequential_forward(
     inputs: torch.Tensor,
+    attention_mask: torch.Tensor,
     prompts: torch.Tensor,
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
@@ -37,10 +38,12 @@ async def sequential_forward(
     """
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+    assert isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 2, f"{type(attention_mask)}: {attention_mask.ndim}"
 
     inputs_device = inputs.device
     inputs_dtype = inputs.dtype
     inputs = inputs.cpu()
+    attention_mask = attention_mask.cpu()
     prompts = prompts.cpu()
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
@@ -68,7 +71,7 @@ async def sequential_forward(
                 span = sequences.popleft()
 
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+                inputs_and_prompts = [inputs, attention_mask, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
@@ -111,6 +114,7 @@ async def sequential_forward(
 async def sequential_backward(
     grad_outputs: Sequence[torch.Tensor],
     intermediate_inputs: List[torch.Tensor],
+    attention_mask: torch.Tensor,
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
@@ -128,6 +132,7 @@ async def sequential_backward(
 
     grad_outputs = [tensor.cpu() for tensor in grad_outputs]
     intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    attention_mask = attention_mask.cpu()
     prompts = prompts.cpu()
 
     grad_prompts_reversed = []
@@ -160,6 +165,7 @@ async def sequential_backward(
                     stub,
                     sequence_manager.rpc_info,
                     inputs,
+                    attention_mask,
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.request_timeout,
@@ -191,25 +197,25 @@ async def sequential_backward(
     return grad_outputs, grad_prompts
 
 
-async def _gather_forward(input_batches, prompt_batches, sequence_manager):
+async def _gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager):
     """Wrapper for asyncio.gather to perform parallel sequential forwards"""
     return await asyncio.gather(
         *[
-            sequential_forward(input_batch, prompt_batch, sequence_manager)
-            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+            sequential_forward(input_batch, attention_mask_batch, prompt_batch, sequence_manager)
+            for input_batch, attention_mask_batch, prompt_batch in zip(input_batches, attention_mask_batches, prompt_batches)
         ]
     )
 
 
 async def _gather_backward(
-    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+    grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences, sequence_manager
 ):
     """Wrapper for asyncio.gather to perform parallel sequential backwards"""
     return await asyncio.gather(
         *[
-            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
-            for grad_output, input_batch, prompt_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, attention_mask_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences
             )
         ]
     )
@@ -222,16 +228,17 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        attention_mask_batches: Sequence[torch.Tensor] = attention_mask.detach().split(batch_size)
         if is_dummy(prompts):
             prompt_batches = [DUMMY] * len(input_batches)
         else:
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)
 
         sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager))
         assert len(outputs) == len(input_batches)
 
         output_batches = [output[0] for output in outputs]
@@ -241,6 +248,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
         ctx.prompt_batches = prompt_batches
         ctx.sequence_manager = sequence_manager
         ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.attention_mask_batches = attention_mask_batches
         ctx.sequences_for_batches = sequences_for_batches
         return torch.cat(output_batches, dim=0)
 
@@ -258,13 +266,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
             _gather_backward(
                 grad_output_batches,
                 intermediate_input_batches,
+                ctx.attention_mask_batches,
                 ctx.prompt_batches,
                 forward_sequences,
                 ctx.sequence_manager,
             )
         )
         grad_input_batches = [output[0][0] for output in outputs]
-        grad_prompt_batches = [output[1] for output in outputs]
+        grad_prompt_batches = [output[2] for output in outputs]
 
         grad_inputs = torch.cat(grad_input_batches, dim=0)
         dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]

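Note: the autograd function splits the attention mask along the batch dimension with the same chunk size as the hidden states, so each sub-batch carries its own mask through both the forward and backward passes. A sketch of the batching arithmetic with hypothetical sizes:

    import torch

    MAX_TOKENS_IN_BATCH = 1024
    inputs = torch.randn(8, 512, 64)  # [batch, seq, hidden]; hidden size shrunk for the example
    attention_mask = torch.ones(8, 512)

    batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)         # 1024 // 512 = 2
    input_batches = inputs.detach().split(batch_size)                   # 4 chunks of 2 sequences each
    attention_mask_batches = attention_mask.detach().split(batch_size)  # stays aligned with input_batches
    assert len(input_batches) == len(attention_mask_batches) == 4
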
+ 1 - 0
src/petals/server/backend.py

@@ -84,6 +84,7 @@ class TransformerBackend(ModuleBackend):
     def inference_step(
         self,
         hidden_states: torch.Tensor,
+        attention_masks: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:

+ 8 - 5
src/petals/server/handler.py

@@ -405,15 +405,16 @@ async def _rpc_forward(
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
-    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :param flat_tensors: a list of tensors that includes first layer inputs, attention_mask, optional prompts and extra tensors
     :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
-    hidden_states, prompts = flat_tensors
+    hidden_states, attention_masks, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
+    attention_masks = attention_masks.to(dtype)
     assert hidden_states.ndim == 3
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
@@ -431,6 +432,7 @@ async def _rpc_forward(
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
             hidden_states,
+            attention_masks,
             priority=priority,
         )
         assert isinstance(hidden_states, torch.Tensor)
@@ -447,9 +449,10 @@ async def _rpc_backward(
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
+    inputs, attention_masks, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
+    attention_masks = attention_masks.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
     if prompts is None or is_dummy(prompts):
@@ -469,7 +472,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(inputs, attention_masks, priority=priority)
 
         assert isinstance(inputs, torch.Tensor)
 
@@ -485,7 +488,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, attention_masks, grad_outputs, priority=priority)
 
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
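
Note: on the server side, the flattened tensor layout gains one slot in each direction: rpc_forward unpacks (hidden_states, attention_masks, prompts) and rpc_backward unpacks (inputs, attention_masks, grad_outputs, prompts), and the mask is passed into every pool submission. A sketch of the new unpacking order, with dummy tensors standing in for the deserialized payload:

    import torch

    forward_flat_tensors = (
        torch.randn(2, 16, 64),  # hidden_states   [batch, seq, hid]
        torch.ones(2, 16),       # attention_masks [batch, seq]
        torch.zeros(1),          # prompts (may be a dummy placeholder)
    )
    hidden_states, attention_masks, prompts = forward_flat_tensors

    backward_flat_tensors = (
        torch.randn(2, 16, 64),  # inputs
        torch.ones(2, 16),       # attention_masks
        torch.randn(2, 16, 64),  # grad_outputs
        torch.zeros(1),          # prompts
    )
    inputs, attention_masks, grad_outputs, prompts = backward_flat_tensors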