|
@@ -1,14 +1,17 @@
|
|
|
import asyncio
|
|
|
+import gc
|
|
|
import random
|
|
|
import time
|
|
|
+from contextlib import suppress
|
|
|
|
|
|
import numpy as np
|
|
|
+import psutil
|
|
|
import pytest
|
|
|
import torch
|
|
|
|
|
|
import hivemind
|
|
|
import hivemind.averaging.averager
|
|
|
-from conftest import cleanup_children
|
|
|
+from hivemind import MPFuture, get_logger
|
|
|
from hivemind.averaging.allreduce import AveragingMode
|
|
|
from hivemind.averaging.key_manager import GroupKeyManager
|
|
|
from hivemind.averaging.load_balancing import load_balance_peers
|
|
@@ -16,6 +19,29 @@ from hivemind.p2p import PeerID
|
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
from test_utils.dht_swarms import launch_dht_instances
|
|
|
|
|
|
+logger = get_logger(__name__)
|
|
|
+
|
|
|
+
|
|
|
def cleanup_children():
    """Yield once, then stop leftover child processes and reset the MPFuture backend.

    The code after the ``yield`` runs when the generator is resumed:

    1. force a garbage collection pass;
    2. list every descendant of the current process (recursively);
    3. if any exist, call ``terminate()`` on each, wait on them with ``timeout=1``,
       then call ``kill()`` on each;
    4. reset the MPFuture backend, whether or not any children were found.

    NOTE(review): no fixture decorator is visible on this definition in this view --
    confirm how the generator is driven.
    """
    yield

    def _signal_each(procs, send):
        # A process may already be gone by the time it is signalled; that is not an error here
        for proc in procs:
            with suppress(psutil.NoSuchProcess):
                send(proc)

    # Collect garbage first so that .__del__() is called for removed objects
    gc.collect()

    leftovers = psutil.Process().children(recursive=True)
    if leftovers:
        logger.info(f"Cleaning up {len(leftovers)} leftover child processes")
        # Ask politely first ...
        _signal_each(leftovers, lambda proc: proc.terminate())
        # ... give the processes a bounded amount of time to exit ...
        psutil.wait_procs(leftovers, timeout=1)
        # ... then kill every listed process (already-exited ones are skipped by the helper)
        _signal_each(leftovers, lambda proc: proc.kill())

    # Broken code or killing of child processes may leave the MPFuture backend corrupted
    MPFuture.reset_backend()
|
|
|
+
|
|
|
+
|
|
|
|
|
|
@pytest.mark.forked
|
|
|
@pytest.mark.asyncio
|