فهرست منبع

Fix format bugs on client side

Aleksandr Borzunov 3 سال پیش
والد
کامیت
0336a04082
4 فایل تغییر یافته به همراه 20 افزوده شده و 8 حذف شده
  1. 2 2
      README.md
  2. 1 1
      src/client/remote_block.py
  3. 3 2
      src/dht_utils.py
  4. 14 3
      src/server/server.py

+ 2 - 2
README.md

@@ -43,7 +43,7 @@ python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6
 # - give each server a unique --identity_path (or remote --identity_path arg when debugging)
 # - give each server a unique --identity_path (or remote --identity_path arg when debugging)
 # - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
 # - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
 # - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
 # - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
-# - each server except first should have --initial_peers pointing to one of pre-existing servers 
+# - each server except first should have --initial_peers pointing to one of pre-existing servers
 ```
 ```
 
 
 Then open a python notebook or console and run:
 Then open a python notebook or console and run:
@@ -66,7 +66,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 loss.backward()
 
 
 # test inference, one block
 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
         res = sess.step(torch.ones(1, 1, 4096))
 ```
 ```

+ 1 - 1
src/client/remote_block.py

@@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers)))  # TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
         super().__init__(peer_info, p2p)
         super().__init__(peer_info, p2p)
 
 
     @property
     @property

+ 3 - 2
src/dht_utils.py

@@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Sequence, Union
 
 
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 
 import src
 import src
@@ -133,13 +133,14 @@ async def _get_remote_module_infos(
         servers = {}
         servers = {}
         for peer_id, server_info in metadata.value.items():
         for peer_id, server_info in metadata.value.items():
             try:
             try:
+                peer_id = PeerID.from_base58(peer_id)
                 server_info = server_info.value
                 server_info = server_info.value
                 if not (isinstance(server_info, tuple) and len(server_info) == 2 and
                 if not (isinstance(server_info, tuple) and len(server_info) == 2 and
                         isinstance(server_info[0], int) and isinstance(server_info[1], float)):
                         isinstance(server_info[0], int) and isinstance(server_info[1], float)):
                     raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
                     raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
                 state, throughput = server_info
                 state, throughput = server_info
                 servers[peer_id] = ServerInfo(ServerState(state), throughput)
                 servers[peer_id] = ServerInfo(ServerState(state), throughput)
-            except ValueError as e:
+            except (TypeError, ValueError) as e:
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
             modules[i] = RemoteModuleInfo(uid, servers)

+ 14 - 3
src/server/server.py

@@ -41,7 +41,8 @@ class Server(threading.Thread):
         **kwargs,
         **kwargs,
     ):
     ):
         threading.Thread.__init__(self)
         threading.Thread.__init__(self)
-        self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
+        self.dht, self.module_backends = dht, module_backends
+        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
         self.conn_handlers = [
         self.conn_handlers = [
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         ]
@@ -165,8 +166,8 @@ class Server(threading.Thread):
             state=ServerState.JOINING,
             state=ServerState.JOINING,
             throughput=throughput,
             throughput=throughput,
         )
         )
+        logger.info(f"Announced that blocks {block_indices} are joining")
 
 
-        logger.info(f"Loading blocks with indices {block_indices}")
         blocks = {}
         blocks = {}
         for module_uid, block_index in zip(module_uids, block_indices):
         for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
             block = load_pretrained_block(
@@ -232,6 +233,16 @@ class Server(threading.Thread):
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         """
+        if self.module_backends:
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+
         self.ready.clear()
         self.ready.clear()
 
 
         for process in self.conn_handlers:
         for process in self.conn_handlers:
@@ -253,7 +264,7 @@ class Server(threading.Thread):
         logger.debug(f"Shutting down runtime")
         logger.debug(f"Shutting down runtime")
 
 
         self.runtime.shutdown()
         self.runtime.shutdown()
-        logger.info("Server shutdown succesfully")
+        logger.info("Server shut down succesfully")
 
 
 
 
 class ModuleAnnouncerThread(threading.Thread):
 class ModuleAnnouncerThread(threading.Thread):