
fix bug: do not raise "no adapter found" if there is no adapter to be found

Your Name, 2 years ago
parent
commit 04cf3183ee
1 file changed, 5 insertions, 3 deletions

src/petals/server/backend.py (+5, -3)

@@ -82,11 +82,13 @@ class TransformerBackend(ModuleBackend):
         return cache_tensors
 
     def forward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        self.load_adapter_(active_adapter)  # empty string means remove any adapters
+        if active_adapter and not self.load_adapter_(active_adapter):
+            raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().forward(*inputs)
 
     def backward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        self.load_adapter_(active_adapter)  # empty string means remove any adapters
+        if active_adapter and not self.load_adapter_(active_adapter):
+            raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().backward(*inputs)
 
     @torch.inference_mode()
@@ -98,7 +100,7 @@ class TransformerBackend(ModuleBackend):
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
-        if not self.load_adapter_(inference_info.active_adapter):
+        if inference_info.active_adapter and not self.load_adapter_(inference_info.active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
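For reference, below is a minimal standalone sketch (not part of this commit, and not the real TransformerBackend) of the guard semantics the change introduces. It assumes load_adapter_ returns True when the requested adapter is available and False otherwise; the class name, constructor, and adapter names are hypothetical.

# Minimal sketch, assuming load_adapter_ returns True/False for availability.
# BackendSketch and the adapter names are hypothetical illustration only.
class BackendSketch:
    def __init__(self, loaded_adapters):
        self.loaded_adapters = set(loaded_adapters)

    def load_adapter_(self, active_adapter: str) -> bool:
        # An empty string means "run without adapters"; a name must be loaded.
        return active_adapter in self.loaded_adapters

    def forward(self, active_adapter: str) -> str:
        # Only raise when an adapter was actually requested but is missing.
        if active_adapter and not self.load_adapter_(active_adapter):
            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
        return "forward ok"

backend = BackendSketch(loaded_adapters=["my-lora"])
print(backend.forward(""))         # no adapter requested -> no KeyError (the fix)
print(backend.forward("my-lora"))  # loaded adapter -> runs normally
try:
    backend.forward("other-lora")  # unknown adapter -> KeyError
except KeyError as err:
    print(err)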