瀏覽代碼

Fixes and test

artek0chumak 2 年之前
父節點
當前提交
d14debea35

+ 0 - 6
src/petals/client/inference_session.py

@@ -107,9 +107,6 @@ class _ServerInferenceSession:
         if attention_mask is None:
             attention_mask = DUMMY
 
-        if attention_mask is None:
-            attention_mask = DUMMY
-
         # serialize inputs and put them into the queue
         inputs = (new_hidden_states, attention_mask, prompts, hypo_ids)
         outputs_serialized = RemoteExpertWorker.run_coroutine(
@@ -239,9 +236,6 @@ class InferenceSession:
         if attention_mask is None:
             attention_mask = DUMMY
 
-        if attention_mask is None:
-            attention_mask = DUMMY
-
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()

+ 0 - 3
src/petals/client/remote_model.py

@@ -192,9 +192,6 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if attention_mask is None:
-            attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
-
         if attention_mask is None:
             attention_mask = torch.ones((batch_size, hidden_states.size(1)), device=hidden_states.device)
 

+ 1 - 1
src/petals/client/routing/sequence_manager.py

@@ -333,7 +333,7 @@ class _SequenceManagerUpdateThread(threading.Thread):
 
 def maybe_log_traceback(exc: Exception):
     traceback_level = logging.DEBUG if str(exc) or isinstance(exc, asyncio.TimeoutError) else logging.WARNING
-    logger.log(logging.INFO, "See detailed traceback below:", exc_info=True)
+    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
 
 
 class MissingBlocksError(RuntimeError):

+ 1 - 1
src/petals/server/backend.py

@@ -84,7 +84,7 @@ class TransformerBackend(ModuleBackend):
     def inference_step(
         self,
         hidden_states: torch.Tensor,
-        attention_masks: torch.Tensor,
+        attention_mask: torch.Tensor,
         hypo_ids: torch.LongTensor,
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:

+ 8 - 8
tests/test_full_model.py

@@ -13,7 +13,8 @@ logger = get_logger(__file__)
 
 @pytest.mark.forked
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+@pytest.mark.parametrize("second_token_attention_mask", (1, 0))
+def test_full_model_exact_match(pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -23,9 +24,11 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    attention_mask = torch.ones_like(test_inputs)
+    attention_mask[0, 1] = second_token_attention_mask
 
     with torch.inference_mode():
-        parallel_outputs = model.forward(test_inputs).logits
+        parallel_outputs = model.forward(test_inputs, attention_mask=attention_mask).logits
         assert torch.all(torch.isfinite(parallel_outputs))
         logger.info("Forward outputs are finite")
 
@@ -37,7 +40,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
                 recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
 
             for t in range(embs.shape[1]):
-                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, :t+1]))
                 if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                     recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
@@ -58,13 +61,10 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
                 ref_model.resize_token_embeddings(config.vocab_size)
                 logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
 
-            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-            # note: this creates a dummy mask to make the test compatible with older transformer versions
-            # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=attention_mask).logits.float()
             assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
             logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
-            del ref_model, ref_outputs, dummy_mask
+            del ref_model, ref_outputs
         else:
             logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
             assert False

+ 1 - 1
tests/test_remote_sequential.py

@@ -71,7 +71,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         rpc_info = super().rpc_info
         dims = (2048, 1024)
         compressed_input_schema = BatchTensorDescriptor(dims, compression=runtime_pb2.CompressionType.FLOAT16)
-        rpc_info["forward_schema"] = (compressed_input_schema,), dict()  # (args, kwargs)
+        rpc_info["forward_schema"] = (compressed_input_schema, compressed_input_schema), dict()  # (args, kwargs)
         return rpc_info
 
     def get_request_metadata(self, protocol: str, *args, **kwargs):