justheuristic 3 vuotta sitten
vanhempi
commit
9ef52f3421
2 muutettua tiedostoa jossa 6 lisäystä ja 4 poistoa
  1. +5 -3
      src/client/remote_model.py
  2. +1 -1
      tests/test_chained_calls.py

+ 5 - 3
src/client/remote_model.py

@@ -65,8 +65,10 @@ class DistributedBloomModel(BloomModel):
 
     def forward(self, *args, use_cache=None, **kwargs):
         if use_cache:
-            raise ValueError("Distributed forward does not support use_cache; for efficient cache-aware generation, "
-                             "please use model.transformer.inference_session() or model.generate(...)")
+            raise ValueError(
+                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
+                "please use model.transformer.inference_session() or model.generate(...)"
+            )
         return super().forward(*args, use_cache=False, **kwargs)
 
 
@@ -136,7 +138,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
-    """ Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
 

+ 1 - 1
tests/test_chained_calls.py

@@ -8,7 +8,7 @@ import os
 import hivemind
 import torch
 import transformers
-from hivemind.moe.expert_uid import ExpertInfo, UID_DELIMITER
+from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock