Procházet zdrojové kódy

Fix TP crashing when hypo_ids are used (#249)

Alexander Borzunov před 2 roky
rodič
revize
3c523ab0d2
Změněn 1 soubor: 1 přidání a 1 odebrání
  1. 1 1
      src/petals/server/backend.py

+ 1 - 1
src/petals/server/backend.py

@@ -99,7 +99,7 @@ class TransformerBackend(ModuleBackend):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
         if not is_dummy(hypo_ids):
             for cache_tensor in cache_tensors:
-                cache_tensor[...] = cache_tensor[hypo_ids]  # in-place reorder cache by hypo ids
+                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids
 
     def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
         """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""