Browse source

Set number of threads to 1 (#65)

* Set number of threads to 1

* Add new benchmarks
Max Ryabinin 5 years ago
parent
commit
989dd347c8
3 changed files with 10 additions and 0 deletions
  1. 1 0
      hivemind/runtime/task_pool.py
  2. 2 0
      hivemind/server/__init__.py
  3. 7 0
      tests/benchmark_throughput.py

+ 1 - 0
hivemind/runtime/task_pool.py

@@ -125,6 +125,7 @@ class TaskPool(TaskPoolBase):
                 total_size += task_size
 
     def run(self, *args, **kwargs):
+        torch.set_num_threads(1)
         logger.info(f'{self.uid} starting, pid={os.getpid()}')
         pending_batches = {}  # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime
         output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches],

+ 2 - 0
hivemind/server/__init__.py

@@ -3,6 +3,7 @@ import os
 import threading
 from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
 from typing import Dict, Optional
+import torch
 
 from .connection_handler import handle_connection
 from .dht_handler import DHTHandlerThread
@@ -121,6 +122,7 @@ class Server(threading.Thread):
 
 def socket_loop(sock, experts):
     """ catch connections, send tasks to processing, respond with results """
+    torch.set_num_threads(1)
     print(f'Spawned connection handler pid={os.getpid()}')
     while True:
         try:

+ 7 - 0
tests/benchmark_throughput.py

@@ -13,6 +13,7 @@ import hivemind
 
 
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
+    torch.set_num_threads(1)
     can_start.wait()
     experts = [hivemind.RemoteExpert(f"expert{i}", port=port) for i in range(num_experts)]
 
@@ -131,6 +132,12 @@ if __name__ == "__main__":
     elif args.preset == 'ffn_small_batch':
         benchmark_throughput(backprop=False, num_experts=4, batch_size=32, max_batch_size=8192,
                              num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == 'ffn_small_batch_512clients':
+        benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192,
+                             num_clients=512, num_batches_per_client=args.num_batches_per_client)
+    elif args.preset == 'ffn_small_batch_512clients_32handlers':
+        benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192, num_handlers=32,
+                             num_clients=512, num_batches_per_client=args.num_batches_per_client)
     elif args.preset == 'ffn_massive':
         soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
         try: