@@ -82,11 +82,13 @@ class TransformerBackend(ModuleBackend):
         return cache_tensors

     def forward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        self.load_adapter_(active_adapter)  # empty string means remove any adapters
+        if active_adapter and not self.load_adapter_(active_adapter):
+            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
         return super().forward(*inputs)

     def backward(self, active_adapter: str, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        self.load_adapter_(active_adapter)  # empty string means remove any adapters
+        if active_adapter and not self.load_adapter_(active_adapter):
+            raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
         return super().backward(*inputs)

     @torch.inference_mode()
@@ -98,7 +100,7 @@ class TransformerBackend(ModuleBackend):
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

-        if not self.load_adapter_(inference_info.active_adapter):
+        if inference_info.active_adapter and not self.load_adapter_(inference_info.active_adapter):
             raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
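
For reference, the change assumes load_adapter_ returns a truthy value only when the requested adapter is actually available, so callers can raise KeyError instead of silently proceeding without it; an empty adapter name skips loading altogether. A minimal sketch of that contract, assuming a hypothetical dict-backed adapter store (AdapterStoreSketch and its fields are illustrative, not the real implementation):

    from typing import Dict, Optional

    class AdapterStoreSketch:
        # Hypothetical stand-in for the backend's adapter bookkeeping.
        def __init__(self, available_adapters: Dict[str, object]):
            self._adapters = available_adapters  # assumed preloaded adapters
            self._active: Optional[object] = None

        def load_adapter_(self, name: str) -> bool:
            # Report failure for unknown adapters; the callers in the diff
            # turn this False into a KeyError for non-empty adapter names.
            if name not in self._adapters:
                return False
            self._active = self._adapters[name]
            return True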