justheuristic 3 years ago
parent
commit
1c68670d06
1 changed files with 5 additions and 5 deletions
  1. 5 5
      src/client/remote_model.py

+ 5 - 5
src/client/remote_model.py

@@ -27,12 +27,11 @@ class DistributedBloomConfig(BloomConfig):
 
 class DistributedBloomModel(BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
+    config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
-        assert self.config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert (
-            self.config.initial_peers or config.dht
-        ), "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
+        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
 
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
@@ -50,9 +49,10 @@ class DistributedBloomModel(BloomModel):
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
-        BloomPreTrainedModel().__init__(config)
+        BloomPreTrainedModel.__init__(self, config)
         self.transformer = DistributedBloomModel(config)
         # Initialize weights and apply final processing
         self.post_init()