@@ -5,7 +5,6 @@ import sys
 import time
 
 import torch
-from test_utils import print_device_info
 
 import hivemind
 from hivemind import find_open_port
@@ -13,6 +12,19 @@ from hivemind.server import layers
 from hivemind.utils.threading import increase_file_limit
 
 
+def print_device_info(device=None):
+    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
+    device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
+    print('Using device:', device)
+
+    # Additional Info when using cuda
+    if device.type == 'cuda':
+        print(torch.cuda.get_device_name(0))
+        print('Memory Usage:')
+        print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
+        print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
+
+
 def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
     torch.set_num_threads(1)
     can_start.wait()
@@ -65,7 +77,8 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
     for i in range(num_experts):
         expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
         experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
-                                                       expert=expert, optimizer=torch.optim.Adam(expert.parameters()),
+                                                       expert=expert,
+                                                       optimizer=torch.optim.Adam(expert.parameters()),
                                                        args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
                                                        outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                                                        max_batch_size=max_batch_size,
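
For reference, a minimal usage sketch of the helper this diff inlines. The importable module name benchmark_throughput is an assumption, not something the diff specifies; note also that torch.cuda.memory_cached, used here, was later deprecated in favor of torch.cuda.memory_reserved.

    # Hypothetical standalone check of the inlined helper; the module name
    # benchmark_throughput is an assumption, not part of the diff.
    from benchmark_throughput import print_device_info

    # With no argument the helper picks 'cuda' when available, else 'cpu';
    # on a CUDA device it additionally reports allocated and cached memory.
    print_device_info()
    print_device_info('cpu')  # force CPU: prints only the "Using device:" line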