Selaa lähdekoodia

Merge remote-tracking branch 'origin/main' into server-dtypes

justheuristic 2 vuotta sitten
vanhempi
commit
260bd70f96

+ 20 - 7
src/bloom/model.py

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """
 
 
+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.  
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
     """,
     BLOOM_START_DOCSTRING,
 )
@@ -436,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):

+ 11 - 3
src/client/inference_session.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import itertools
+import logging
 import time
 from typing import AsyncIterator, List, Optional
 
@@ -18,7 +19,6 @@ from hivemind import (
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -218,6 +218,11 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
 
+        inputs_device = inputs.device
+        inputs_dtype = inputs.dtype
+        inputs = inputs.cpu()
+        prompts = prompts.cpu()
+
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
             raise ValueError(
@@ -300,11 +305,14 @@ class InferenceSession:
                         f"Caught exception when running inference from block {block_idx} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
-                    logger.debug("See detailed traceback below:", exc_info=True)
+                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-        return inputs
+
+        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        return outputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""

+ 1 - 1
src/client/remote_model.py

@@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:

+ 0 - 1
src/client/remote_sequential.py

@@ -31,7 +31,6 @@ class RemoteSequential(nn.Module):
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
-        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         super().__init__()
         self.config = config
         self.dht = dht

+ 1 - 1
src/client/sequence_manager.py

@@ -30,7 +30,7 @@ class RemoteSequenceManager:
         block_uids: Sequence[ModuleUID],
         p2p: P2P,
         max_retries: int = 3,
-        timeout: float = 5,
+        timeout: float = 20,
         min_backoff: float = 1,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"

+ 27 - 3
src/client/sequential_autograd.py

@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
+import logging
 from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
@@ -36,6 +37,11 @@ async def sequential_forward(
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
+    inputs_device = inputs.device
+    inputs_dtype = inputs.dtype
+    inputs = inputs.cpu()
+    prompts = prompts.cpu()
+
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
     assert is_dummy(prompts) or len(prompts) == len(
@@ -86,9 +92,12 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
+    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
     return outputs, intermediate_inputs, done_sequences
 
 
@@ -98,13 +107,22 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
     assert len(intermediate_inputs) == len(forward_sequences)
 
+    grad_outputs_device = grad_outputs[0].device if grad_outputs else None
+    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
+    prompts_device = prompts.device
+    prompts_dtype = prompts.dtype
+
+    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
+    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    prompts = prompts.cpu()
+
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         inputs = intermediate_inputs.pop()
@@ -146,12 +164,18 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+
+    if grad_outputs_dtype is not None:
+        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
+    if grad_prompts is not None:
+        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
     return grad_outputs, grad_prompts