Commit d2d0403a11 by artek0chumak, 2 years ago

+ 3 - 4
src/petals/bloom/block.py

@@ -18,20 +18,19 @@ class WrappedBloomBlock(BloomBlock):
    def forward(
        self,
        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor,
        *args,
        alibi: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs
    ):
-        assert attention_mask is None
        batch_size, seq_length = hidden_states.shape[:2]
        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
        if alibi is None:
            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        causal_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
        return super().forward(
-            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
+            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
        )

    def _prepare_attn_mask(
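
Note: after this change WrappedBloomBlock no longer accepts attention_mask=None; callers must pass an explicit 2-D mask, and the block derives both the ALiBi tensor and the causal mask from it. A minimal sketch of the new calling convention, assuming `block` is an already-initialized WrappedBloomBlock and the variable names are illustrative:

    import torch

    # hidden_states: [batch_size, seq_length, hidden_size]; `block` is a WrappedBloomBlock (assumption)
    attention_mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
    outputs = block(hidden_states, attention_mask)  # alibi and the causal mask are built inside forward()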

+ 3 - 3
src/petals/client/remote_forward_backward.py

@@ -88,8 +88,8 @@ async def run_remote_forward(
    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
    # TODO: rm this assert when support arbitrary number of input tensors
-    assert len(args_schema) == 1 and len(inputs) == 2
-    forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
+    assert len(args_schema) == 2 and len(inputs) == 3
+    forward_schema_with_prompts = ((args_schema[0], args_schema[1], args_schema[0]), kwargs_schema)

    if not nested_compare(forward_inputs, forward_schema_with_prompts):
        raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
@@ -135,7 +135,7 @@ async def run_remote_backward(

    # Modify forward_schema to support prompts
    args_schema, kwargs_schema = rpc_info["forward_schema"]
-    assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
+    assert len(args_schema) == 2 and isinstance(inputs, torch.Tensor)
    # TODO generalize this
    prompts_schema = next(iter(args_schema))
    backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
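
Note: with the mask added, the expert's forward schema is assumed to hold two positional descriptors (hidden states and attention mask), and the client ships three tensors per span. A sketch of how the schema lines up with the payload (descriptor names are placeholders, not repo identifiers):

    # args_schema == (hidden_states_descr, attention_mask_descr) after this commit (assumption)
    hidden_descr, mask_descr = args_schema
    forward_schema_with_prompts = ((hidden_descr, mask_descr, hidden_descr), kwargs_schema)
    # matching payload order on the wire: [hidden_states, attention_mask, prompts]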

+ 4 - 4
src/petals/client/remote_model.py

@@ -184,16 +184,16 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

+        batch_size = inputs_embeds.shape[0]
        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, input_shape[-1]), device=hidden_states.device)
-
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)

        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
            hidden_states = self.h(hidden_states, attention_mask, prompts=intermediate_prompts)
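
Note: the removed lines appear to reference hidden_states and batch_size before they were defined outside "ptune" mode, and they sized the default mask from input_shape[-1]. The default mask is now created after any soft prompts are prepended, so its length matches the full hidden-states sequence. A sketch of the resulting shapes in "ptune" mode (pre_seq_len is an illustrative name, not taken from this diff):

    # input_ids:     [batch_size, seq_length]
    # inputs_embeds: [batch_size, pre_seq_len + seq_length, hidden_size] after torch.cat above
    # default mask:  torch.ones((batch_size, hidden_states.size(1)))  -> covers the prompt positions too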

+ 2 - 2
src/petals/client/remote_sequential.py

@@ -51,10 +51,10 @@ class RemoteSequential(nn.Module):
            assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
            self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids

-    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+    def forward(self, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor = DUMMY):
        assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
        assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+        outputs = _RemoteSequentialAutogradFunction.apply(inputs, attention_mask, prompts, self.sequence_manager)
        return outputs

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
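
Note: RemoteSequential.forward now takes the attention mask as a required second positional argument. A hypothetical call site (names are illustrative, not taken from the repo):

    # remote_blocks: a RemoteSequential; inputs_embeds: [batch_size, seq_length, hidden_size]
    attention_mask = torch.ones(inputs_embeds.shape[:2], device=inputs_embeds.device)
    hidden_states = remote_blocks(inputs_embeds, attention_mask)  # prompts still default to DUMMY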

+ 20 - 11
src/petals/client/sequential_autograd.py

@@ -25,6 +25,7 @@ MAX_TOKENS_IN_BATCH = 1024

async def sequential_forward(
    inputs: torch.Tensor,
+    attention_mask: torch.Tensor,
    prompts: torch.Tensor,
    sequence_manager: RemoteSequenceManager,
    start_index: int = 0,
@@ -37,10 +38,12 @@ async def sequential_forward(
    """

    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
+    assert isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 2, f"{type(attention_mask)}: {attention_mask.ndim}"

    inputs_device = inputs.device
    inputs_dtype = inputs.dtype
    inputs = inputs.cpu()
+    attention_mask = attention_mask.cpu()
    prompts = prompts.cpu()

    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
@@ -68,7 +71,7 @@ async def sequential_forward(
                span = sequences.popleft()

                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
-                inputs_and_prompts = [inputs, prompts[span.start : span.end]]
+                inputs_and_prompts = [inputs, attention_mask, prompts[span.start : span.end]]

                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                metadata = sequence_manager.get_request_metadata("rpc_forward", span_uids, *inputs_and_prompts)
@@ -111,6 +114,7 @@ async def sequential_forward(
async def sequential_backward(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: List[torch.Tensor],
+    attention_mask: torch.Tensor,
    prompts: torch.Tensor,
    forward_sequences: List[RemoteSpanInfo],
    sequence_manager: RemoteSequenceManager,
@@ -128,6 +132,7 @@ async def sequential_backward(

    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    attention_mask = attention_mask.cpu()
    prompts = prompts.cpu()

    grad_prompts_reversed = []
@@ -160,6 +165,7 @@ async def sequential_backward(
                    stub,
                    sequence_manager.rpc_info,
                    inputs,
+                    attention_mask,
                    grad_outputs,
                    prompts[span.start : span.end],
                    timeout=sequence_manager.request_timeout,
@@ -191,25 +197,25 @@ async def sequential_backward(
    return grad_outputs, grad_prompts


-async def _gather_forward(input_batches, prompt_batches, sequence_manager):
+async def _gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager):
    """Wrapper for asyncio.gather to perform parallel sequential forwards"""
    return await asyncio.gather(
        *[
-            sequential_forward(input_batch, prompt_batch, sequence_manager)
-            for input_batch, prompt_batch in zip(input_batches, prompt_batches)
+            sequential_forward(input_batch, attention_mask_batch, prompt_batch, sequence_manager)
+            for input_batch, attention_mask_batch, prompt_batch in zip(input_batches, attention_mask_batches, prompt_batches)
        ]
    )


async def _gather_backward(
-    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
+    grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences, sequence_manager
):
    """Wrapper for asyncio.gather to perform parallel sequential backwards"""
    return await asyncio.gather(
        *[
-            sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager)
-            for grad_output, input_batch, prompt_batch, spans in zip(
-                grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences
+            sequential_backward((grad_output,), input_batch, attention_mask_batch, prompt_batch, spans, sequence_manager)
+            for grad_output, input_batch, attention_mask_batch, prompt_batch, spans in zip(
+                grad_output_batches, intermediate_input_batches, attention_mask_batches, prompt_batches, forward_sequences
            )
        ]
    )
@@ -222,16 +228,17 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """

    @staticmethod
-    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
+    def forward(ctx, inputs: torch.Tensor, attention_mask: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
+        attention_mask_batches: Sequence[torch.Tensor] = attention_mask.detach().split(batch_size)
        if is_dummy(prompts):
            prompt_batches = [DUMMY] * len(input_batches)
        else:
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

        sequence_manager.rpc_info  # lazy init
-        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
+        outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, attention_mask_batches, prompt_batches, sequence_manager))
        assert len(outputs) == len(input_batches)

        output_batches = [output[0] for output in outputs]
@@ -241,6 +248,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
        ctx.prompt_batches = prompt_batches
        ctx.sequence_manager = sequence_manager
        ctx.intemediate_input_batches = intemediate_input_batches
+        ctx.attention_mask_batches = attention_mask_batches
        ctx.sequences_for_batches = sequences_for_batches
        return torch.cat(output_batches, dim=0)

@@ -258,13 +266,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
            _gather_backward(
                grad_output_batches,
                intermediate_input_batches,
+                ctx.attention_mask_batches,
                ctx.prompt_batches,
                forward_sequences,
                ctx.sequence_manager,
            )
        )
        grad_input_batches = [output[0][0] for output in outputs]
-        grad_prompt_batches = [output[1] for output in outputs]
+        grad_prompt_batches = [output[2] for output in outputs]

        grad_inputs = torch.cat(grad_input_batches, dim=0)
        dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches]
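
Note: the autograd function splits the mask with the same row count as the inputs, so each micro-batch sent to a span carries its own mask rows. A self-contained sketch of that splitting (shapes are illustrative; MAX_TOKENS_IN_BATCH = 1024 as defined in this file):

    import torch

    MAX_TOKENS_IN_BATCH = 1024
    inputs = torch.randn(8, 512, 64)                # [batch, seq, hidden], small hidden size for illustration
    attention_mask = torch.ones(8, 512)

    batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)          # -> 2 sequences per micro-batch
    input_batches = inputs.detach().split(batch_size)                    # split along dim 0
    attention_mask_batches = attention_mask.detach().split(batch_size)   # identical split sizes
    assert all(x.shape[0] == m.shape[0] for x, m in zip(input_batches, attention_mask_batches))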

+ 1 - 0
src/petals/server/backend.py

@@ -84,6 +84,7 @@ class TransformerBackend(ModuleBackend):
    def inference_step(
        self,
        hidden_states: torch.Tensor,
+        attention_masks: torch.Tensor,
        hypo_ids: torch.LongTensor,
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
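
Note: inference_step now receives the mask positionally between the hidden states and the hypothesis ids. A hypothetical call, assuming `backend` is a TransformerBackend and the step returns a single hidden-states tensor as its annotation suggests:

    (hidden_states,) = backend.inference_step(hidden_states, attention_masks, hypo_ids, inference_info)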

+ 8 - 5
src/petals/server/handler.py

@@ -405,15 +405,16 @@ async def _rpc_forward(
     """
     """
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
     Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream
 
 
-    :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors
+    :param flat_tensors: a list of tensors that includes first layer inputs, attention_mask, optional prompts and extra tensors
     :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
     :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy)
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
     """
     """
-    hidden_states, prompts = flat_tensors
+    hidden_states, attention_masks, prompts = flat_tensors
     dtype = requested_backends[0].dtype
     dtype = requested_backends[0].dtype
     # check parse input tensors and cast dtypes
     # check parse input tensors and cast dtypes
     hidden_states = hidden_states.to(dtype)
     hidden_states = hidden_states.to(dtype)
+    attention_masks = attention_masks.to(dtype)
     assert hidden_states.ndim == 3
     assert hidden_states.ndim == 3
     if prompts is None or is_dummy(prompts):
     if prompts is None or is_dummy(prompts):
         prompts = [DUMMY] * len(requested_backends)
         prompts = [DUMMY] * len(requested_backends)
@@ -431,6 +432,7 @@ async def _rpc_forward(
         )
         )
         (hidden_states,) = await backend.forward_pool.submit_task(
         (hidden_states,) = await backend.forward_pool.submit_task(
             hidden_states,
             hidden_states,
+            attention_masks,
             priority=priority,
             priority=priority,
         )
         )
         assert isinstance(hidden_states, torch.Tensor)
         assert isinstance(hidden_states, torch.Tensor)
@@ -447,9 +449,10 @@ async def _rpc_backward(
     prioritizer: TaskPrioritizerBase,
     prioritizer: TaskPrioritizerBase,
     points: int = 0,
     points: int = 0,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
-    inputs, grad_outputs, prompts = flat_tensors
+    inputs, attention_masks, grad_outputs, prompts = flat_tensors
     # Cast inputs & grad outputs to backend dtype
     # Cast inputs & grad outputs to backend dtype
     inputs = inputs.to(requested_backends[0].dtype)
     inputs = inputs.to(requested_backends[0].dtype)
+    attention_masks = attention_masks.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)
 
 
     if prompts is None or is_dummy(prompts):
     if prompts is None or is_dummy(prompts):
@@ -469,7 +472,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
         priority = prioritizer.prioritize(
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
             inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
         )
         )
-        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+        (inputs,) = await backend.forward_pool.submit_task(inputs, attention_masks, priority=priority)
 
 
         assert isinstance(inputs, torch.Tensor)
         assert isinstance(inputs, torch.Tensor)
 
 
@@ -485,7 +488,7 @@ async def _rpc_backward(
         priority = prioritizer.prioritize(
         priority = prioritizer.prioritize(
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
             inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
         )
         )
-        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+        (grad_outputs,) = await backend.backward_pool.submit_task(inp, attention_masks, grad_outputs, priority=priority)
 
 
         assert isinstance(grad_outputs, torch.Tensor)
         assert isinstance(grad_outputs, torch.Tensor)
         if not is_dummy(prompt):
         if not is_dummy(prompt):
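
Note: for reference, the tensor layouts the two handlers above now unpack (inferred from the unpacking lines in this diff, not an official wire-format description):

    # _rpc_forward:  flat_tensors == (hidden_states, attention_masks, prompts)
    # _rpc_backward: flat_tensors == (inputs, attention_masks, grad_outputs, prompts)
    # in both cases the mask is cast to the first requested backend's dtype before use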