@@ -65,8 +65,10 @@ class DistributedBloomModel(BloomModel):
     def forward(self, *args, use_cache=None, **kwargs):
         if use_cache:
-            raise ValueError("Distributed forward does not support use_cache; for efficient cache-aware generation, "
-                             "please use model.transformer.inference_session() or model.generate(...)")
+            raise ValueError(
+                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
+                "please use model.transformer.inference_session() or model.generate(...)"
+            )
         return super().forward(*args, use_cache=False, **kwargs)
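
Reviewer note: a minimal sketch of the call patterns the new error message points callers to. The checkpoint name, tokenizer, prompt, and per-step session API are illustrative assumptions rather than part of this diff; only model.generate(...) and model.transformer.inference_session() come from the message itself.

import torch
from transformers import AutoTokenizer

def demo_cache_aware_generation(model, checkpoint="bigscience/bloom"):
    """Sketch only: `model` is assumed to be a DistributedBloomForCausalLM already connected to the swarm."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # illustrative checkpoint name
    input_ids = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

    # Plain forward: use_cache is forced to False by this patch, so this path no longer raises.
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits

    # Cache-aware generation, as the new error message suggests.
    output_ids = model.generate(input_ids, max_new_tokens=8)
    print(tokenizer.decode(output_ids[0]))

    # Reusing one server-side inference session (assumed context-manager usage, per the message).
    with model.transformer.inference_session() as session:
        ...  # the per-step session API is outside the scope of this diff
    return logits, output_ids
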
@@ -136,7 +138,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
 class DistributedBloomForCausalLM(BloomForCausalLM):
-    """ Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     config_class = DistributedBloomConfig
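
Reviewer note: the second hunk only tightens the docstring, but since it touches DistributedBloomForCausalLM, here is a hedged loading sketch. The checkpoint name is a placeholder and any swarm bootstrap arguments that from_pretrained may need are omitted; the point is that config_class = DistributedBloomConfig is what lets from_pretrained resolve the distributed config.

from transformers import AutoTokenizer

def load_distributed_causal_lm(model_cls, checkpoint="bigscience/bloom"):
    """Sketch only: pass DistributedBloomForCausalLM (the class in this diff) as `model_cls`;
    the checkpoint name is a placeholder and swarm-connection arguments are not shown."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Because config_class = DistributedBloomConfig, from_pretrained builds the distributed config here.
    model = model_cls.from_pretrained(checkpoint)
    return tokenizer, model

# Hedged usage: tokenizer, model = load_distributed_causal_lm(DistributedBloomForCausalLM)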