Max Ryabinin — 4 years ago
parent
commit
553ace9353

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -80,7 +80,7 @@ def main():
     parser.add_argument('--averaging_min_refresh_period',type=float,default=1)
     parser.add_argument('--averaging_max_refresh_period',type=float,default=60)
     parser.add_argument('--averaging_default_refresh_period',type=float,default=10)
-    parser.add_argument('--averaging_expiration',type=float,default=30)
+    parser.add_argument('--averaging_expiration',type=float,default=10)
     parser.add_argument('--metadata_expiration',type=float,default=120)
     parser.add_argument('--averaging_timeout',type=float,default=30)
     parser.add_argument('--reuse_grad_buffers',type=bool,default=True)

+ 11 - 9
hivemind/moe/client/balanced_expert.py

@@ -118,15 +118,15 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
         while True:
-            # try:
-            with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
-                forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
-                outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
-            break
-            # except KeyboardInterrupt:
-            #     break
-            # except BaseException as e:
-            #     logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
+            try:
+                with expert_balancer.use_another_expert(forward_task_size) as chosen_expert:
+                    forward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
+                    outputs = chosen_expert.stub.forward(forward_request, timeout=forward_timeout)
+                break
+            except KeyboardInterrupt:
+                raise
+            except BaseException as e:
+                logger.error(f"Tried to call forward for expert {chosen_expert} but caught {repr(e)}")
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
         return tuple(deserialized_outputs)
@@ -147,6 +147,8 @@ class _BalancedRemoteModuleCall(torch.autograd.Function):
                     backward_request = runtime_pb2.ExpertRequest(uid=chosen_expert.uid, tensors=serialized_tensors)
                     grad_inputs = chosen_expert.stub.backward(backward_request, timeout=ctx.backward_timeout)
                 break
+            except KeyboardInterrupt:
+                raise
             except BaseException as e:
                 logger.error(f"Tried to call backward for expert {chosen_expert} but caught {repr(e)}")
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]

+ 1 - 0
hivemind/moe/client/balancer.py

@@ -62,6 +62,7 @@ class ExpertBalancer:
                 )
             if len(self.queue) == 0:
                 logger.warning("Update routine finished, but still no experts available.")
+                time.sleep()
 
             self.last_update = get_dht_time()
             self.update_finished.set()

+ 2 - 2
hivemind/moe/server/__init__.py

@@ -179,11 +179,11 @@ class Server(threading.Thread):
             dht = None
         else:
             dht_port = dht_port or hivemind.get_free_port()
-            host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}"]
+            host_maddrs = []
             announce_maddrs = []
 
             if dht_listen_on is not None:
-                dht_maddr = f"/ip6/{dht_listen_on}/tcp/{dht_port}"
+                dht_maddr = f"/{dht_listen_on}/tcp/{dht_port}"
                 host_maddrs.append(dht_maddr)
                 announce_maddrs.append(dht_maddr)