@@ -35,10 +35,6 @@ class TransformerBackend(ModuleBackend):
         max_chunk_size_bytes: int,
         **kwargs,
     ):
-        import petals.utils.peft as _peft_module
-
-        self._peft_module = _peft_module
-
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -98,13 +94,11 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().forward(*inputs)
+        return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        with self._peft_module.using_adapter(active_adapter):
-            return super().backward(*inputs)
+        return super().backward(*inputs)
 
     @torch.inference_mode()
     def inference_step(
@@ -116,9 +110,7 @@ class TransformerBackend(ModuleBackend):
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
         seq_len = hidden_states.shape[1]
 
-        with self.memory_cache.use_cache(
-            *inference_info.cache_handles
-        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
+        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
 
             # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`