
Fix errors in hivemind.p2p and hivemind.compression (#565)

This PR:

1. Fixes warnings in hivemind.p2p destructors.

2. Makes bfloat16 serialization in hivemind.compression forward- and backward-compatible. The code before this PR (a) didn't work in torch < 1.13.0 (hivemind requires torch >= 1.9.0) and (b) led to warnings on torch >= 2.0. The new code works without warnings in all versions of PyTorch.
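A quick way to see the intended behavior is a round-trip check; this is a minimal sketch (not part of the commit), assuming the `serialize_torch_tensor` / `deserialize_torch_tensor` helpers exposed by `hivemind.compression`:

```python
# Sketch: bfloat16 round trip through hivemind's tensor serialization.
# Assumes serialize_torch_tensor / deserialize_torch_tensor from hivemind.compression.
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor

original = torch.randn(4, 8, dtype=torch.bfloat16)
restored = deserialize_torch_tensor(serialize_torch_tensor(original))

assert restored.dtype == torch.bfloat16
assert torch.equal(restored, original)
```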
Alexander Borzunov, 2 years ago
commit 0d2614d7e8

+ 9 - 10
hivemind/compression/base.py

@@ -87,12 +87,11 @@ class NoCompression(CompressionBase):
         dtype_name = str(tensor.dtype).lstrip("torch.")
         raw_data = tensor
         if tensor.dtype == torch.bfloat16:
-            if USE_LEGACY_BFLOAT16:
+            if USE_LEGACY_BFLOAT16:  # legacy mode: convert to fp32
                 raw_data = tensor.to(torch.float32)
-            else:
-                typed_storage = tensor.storage()
-                storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
-                raw_data = torch.tensor(storage, dtype=torch.int8)
+            else:  # efficient mode: send bfloat16 data directly
+                # reinterpret_cast to an arbitrary 2-byte type supported by numpy
+                raw_data = tensor.view(torch.int16)
 
         return runtime_pb2.Tensor(
             compression=self.compression_type,
@@ -106,13 +105,13 @@ class NoCompression(CompressionBase):
         shape = torch.Size(serialized_tensor.size)
         if serialized_tensor.dtype == "bfloat16":
             numel = shape.numel()
-            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:  # legacy mode: convert to fp32
+            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:
                 array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
                 tensor = torch.as_tensor(array, dtype=torch.bfloat16)
-            else:  # efficient mode: send bfloat16 data directly
-                storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
-                storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
-                tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
+            else:
+                array = np.frombuffer(serialized_tensor.buffer, dtype=np.int16)
+                # reinterpret_cast from an arbitrary 2-byte type supported by numpy
+                tensor = torch.as_tensor(array).view(torch.bfloat16)
         else:
             array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
             tensor = torch.as_tensor(array)
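For illustration, here is a standalone sketch (not from the commit) of the reinterpret-cast trick used above: numpy has no bfloat16 dtype, so the raw bits travel as a 2-byte integer type and are viewed back as bfloat16 on the receiving side. The variable names are illustrative only.

```python
# Sketch: move bfloat16 bits through a numpy buffer by viewing them as int16.
import numpy as np
import torch

tensor = torch.randn(3, 5, dtype=torch.bfloat16)

# serialize: reinterpret the same bytes as int16, then hand them to numpy
buffer = tensor.view(torch.int16).numpy().tobytes()

# deserialize: read the int16 bits back and reinterpret them as bfloat16
array = np.frombuffer(buffer, dtype=np.int16)
restored = torch.as_tensor(array).view(torch.bfloat16).reshape(tensor.shape)

assert torch.equal(restored, tensor)

# the legacy check above relies on payload size: fp32 buffers carry
# 4 bytes per element, bfloat16 buffers carry 2
assert len(buffer) // tensor.numel() == 2
```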

+ 3 - 2
hivemind/p2p/p2p_daemon.py

@@ -654,8 +654,9 @@ class P2P:
 
         self._alive = False
         if self._child is not None and self._child.returncode is None:
-            self._child.terminate()
-            logger.debug(f"Terminated p2pd with id = {self.peer_id}")
+            with suppress(ProcessLookupError):
+                self._child.terminate()
+                logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
             with suppress(FileNotFoundError):
                 os.remove(self._daemon_listen_maddr["unix"])
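The pattern here: by the time the destructor runs, the p2pd child process may already have exited, in which case `terminate()` raises `ProcessLookupError`. A hedged sketch of the same idea with an illustrative helper (the `stop_daemon` name and signature are not hivemind API):

```python
# Sketch: terminating a child process that may already be gone.
import asyncio
from contextlib import suppress

async def stop_daemon(child: asyncio.subprocess.Process) -> None:
    if child.returncode is None:  # process still looks alive
        with suppress(ProcessLookupError):
            child.terminate()  # no-op if the process exited in the meantime
        await child.wait()  # reap the child to avoid a zombie
```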

+ 2 - 1
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -47,7 +47,8 @@ class Client:
         return client
 
     def close(self) -> None:
-        self.control.close()
+        if self.control is not None:
+            self.control.close()
 
     def __del__(self):
         self.close()
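The guard matters because `__del__` can run on a partially constructed object. A minimal sketch of the pattern with an illustrative class (not hivemind's `Client`):

```python
# Sketch: close() must tolerate an object whose connection was never opened.
class Connection:
    def __init__(self) -> None:
        self.control = None  # populated only after the control stream is opened

    def close(self) -> None:
        if self.control is not None:
            self.control.close()

    def __del__(self) -> None:
        self.close()  # safe even if __init__ never opened the control stream
```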