|
@@ -15,7 +15,7 @@ from src import DistributedBloomConfig
|
|
|
from src.bloom.block import BloomBlock
|
|
|
from src.server.cache import MemoryCache
|
|
|
from src.server.backend import BloomBlockBackend
|
|
|
-from src.server.handler import BloomConnectionHandler
|
|
|
+from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -31,7 +31,7 @@ class Server(threading.Thread):
|
|
|
):
|
|
|
threading.Thread.__init__(self)
|
|
|
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
|
|
|
- self.conn_handlers = [BloomConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
|
|
|
+ self.conn_handlers = [TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
|
|
|
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
|
|
|
self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
|
|
|
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
|
@@ -105,15 +105,17 @@ class Server(threading.Thread):
|
|
|
blocks = {}
|
|
|
for i in range(num_blocks):
|
|
|
module_uid = f"dummy_block.{i}"
|
|
|
- HARDCODCED_LENGTH = 2048
|
|
|
+ block = BloomBlock(block_config, layer_number=i)
|
|
|
+ for param in block.parameters():
|
|
|
+ param.requires_grad = False
|
|
|
|
|
|
blocks[module_uid] = BloomBlockBackend(
|
|
|
module_uid,
|
|
|
- BloomBlock(block_config, layer_number=i),
|
|
|
+ block,
|
|
|
memory_cache=memory_cache,
|
|
|
- args_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
|
|
|
+ args_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
|
|
|
kwargs_schema={},
|
|
|
- outputs_schema=(BatchTensorDescriptor(1, HARDCODCED_LENGTH, block_config.hidden_size, compression=compression),),
|
|
|
+ outputs_schema=(BatchTensorDescriptor(1, 2048, block_config.hidden_size, compression=compression),),
|
|
|
min_batch_size=min_batch_size,
|
|
|
max_batch_size=max_batch_size,
|
|
|
)
|
|
@@ -121,7 +123,6 @@ class Server(threading.Thread):
|
|
|
return cls(
|
|
|
dht,
|
|
|
blocks,
|
|
|
- cache_size_bytes=cache_size_bytes,
|
|
|
num_connection_handlers=num_handlers,
|
|
|
device=device,
|
|
|
stats_report_interval=stats_report_interval,
|