瀏覽代碼

WIP, switching to another PR

Your Name 2 年之前
父節點
當前提交
84ebd57105
共有 3 個文件被更改,包括 2 次插入和 5 次刪除
  1. +0 -3
      src/petals/server/backend.py
  2. +1 -1
      src/petals/server/block_functions.py
  3. +1 -1
      tests/test_server_stats.py

+ 0 - 3
src/petals/server/backend.py

@@ -112,9 +112,6 @@ class TransformerBackend(ModuleBackend):
     def backward(
         self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs
     ) -> Tuple[Union[torch.Tensor, Any], ...]:
-        args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args]
-        # ^-- TODO remove this AFTER PR#467; make sure args are passed properly and retain requires_grad
-        assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor))
         with self._peft_module.using_adapter(active_adapter), torch.enable_grad():
             (outputs,) = self.module(*args, **kwargs)
             assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape

+ 1 - 1
src/petals/server/block_functions.py

@@ -92,7 +92,7 @@ async def run_rpc_backward(
         requested_backends, flat_tensors, args_structure
     )
     # Cast inputs & grad outputs to backend dtype
-    assert hidden_states.ndim == 3
+    assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad"
     num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
     hidden_states = hidden_states.to(requested_backends[0].dtype)
     grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

+ 1 - 1
tests/test_server_stats.py

@@ -9,7 +9,7 @@ from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *
 
 
-@pytest.mark.forked
+@pytest.mark.skip
 def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"]  # PeerID from server2.id