瀏覽代碼

Fix TP crashing when hypo_ids are used (#249)

Alexander Borzunov 2 年之前
父節點
當前提交
3c523ab0d2
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      src/petals/server/backend.py

+ 1 - 1
src/petals/server/backend.py

@@ -99,7 +99,7 @@ class TransformerBackend(ModuleBackend):
         """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
         if not is_dummy(hypo_ids):
             for cache_tensor in cache_tensors:
-                cache_tensor[...] = cache_tensor[hypo_ids]  # in-place reorder cache by hypo ids
+                cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)]  # in-place reorder cache by hypo ids
 
     def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
         """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""