Parcourir la source

Fix p2p pushing in rpc_inference (by @miaoqijun ) , support transformers 4.38.2 (#563)

This pull request solves #560 using a solution proposed by @miaoqijun .
It also bumps transformers to the latest version to test with the latest code.

---------

Co-authored-by: Yingtong Dou <ytongdou@gmail.com>
justheuristic il y a 1 an
Parent
commit
2ad0b2b936

+ 1 - 1
setup.cfg

@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.37.1  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.38.2  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2

+ 2 - 2
src/petals/__init__.py

@@ -22,8 +22,8 @@ __version__ = "2.3.0.dev2"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.39.0"
 
 
 def _override_bfloat16_mode_default():

+ 13 - 4
src/petals/models/llama/block.py

@@ -50,9 +50,15 @@ class OptimizedLlamaAttention(LlamaAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
@@ -84,9 +90,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -160,6 +165,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = residual + hidden_states

+ 3 - 0
src/petals/models/llama/model.py

@@ -47,6 +47,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"

+ 2 - 2
src/petals/server/block_functions.py

@@ -153,7 +153,7 @@ async def iterate_rpc_inference(
     points: int,
     quant_type: QuantType,
     args_structure: Any = None,
-) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
     assert len(cache_handles) == len(requested_backends)
 
     prefix_length = 0
@@ -224,7 +224,7 @@ async def iterate_rpc_inference(
             for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
         ]
         can_push = not has_prompts
-        yield output_tensors, can_push
+        yield output_tensors, can_push, step_metadata
 
         # prepare for next step
         prefix_length += length_increment

+ 2 - 2
src/petals/server/handler.py

@@ -171,7 +171,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
                 ) as cache_handles:
                     background_tasks = set()
-                    async for output_tensors, can_push in iterate_rpc_inference(
+                    async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
                         active_adapter=self._get_active_adapter(metadata),
@@ -186,7 +186,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                         args_structure=args_structure,
                     ):
                         if can_push:
-                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                            task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
                             background_tasks.add(task)  # Keep reference until it is done to save it from GC
                             task.add_done_callback(background_tasks.discard)
                         yield runtime_pb2.ExpertResponse(tensors=output_tensors)