
Merge remote-tracking branch 'origin/main' into server-dtypes

justheuristic, 2 years ago
parent
current commit
260bd70f96

+ 20 - 7
src/bloom/model.py

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -108,11 +108,24 @@ BLOOM_INPUTS_DOCSTRING = r"""
 """
 """
 
 
 
 
+class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
 @add_start_docstrings(
     "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
     BLOOM_START_DOCSTRING,
 )
-class BloomModel(BloomPreTrainedModel):
+class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
     def __init__(self, config):
         super().__init__(config)
         assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
@@ -277,7 +290,7 @@ class BloomModel(BloomPreTrainedModel):
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
@@ -400,8 +413,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
 @add_start_docstrings(
     """
     The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries. 
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.  
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
@@ -436,7 +449,7 @@ class LMHead(nn.Module):
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
             hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings).float()
+            lm_logits = F.linear(hidden_states, word_embeddings)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
@@ -470,7 +483,7 @@ class LMHead(nn.Module):
     """,
     """,
     BLOOM_START_DOCSTRING,
     BLOOM_START_DOCSTRING,
 )
 )
-class BloomForSequenceClassification(BloomPreTrainedModel):
+class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
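
In short, model.py now routes the Petals-facing BLOOM classes through `_BloomPreTrainedModelWithModifiedDefaults`, so `low_cpu_mem_usage` defaults to `True` unless the caller overrides it, and `LMHead` keeps `lm_logits` in the embeddings' dtype instead of upcasting to fp32. A minimal sketch of how the new default behaves (the checkpoint name is only an example, not part of this commit):

```python
from src.bloom.model import BloomForCausalLM

# With this change, low_cpu_mem_usage is implicitly True for Petals model classes.
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

# An explicit value still wins over the Petals default.
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", low_cpu_mem_usage=False)
```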

+ 11 - 3
src/client/inference_session.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import itertools
+import logging
 import time
 from typing import AsyncIterator, List, Optional
 
@@ -18,7 +19,6 @@ from hivemind import (
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import aiter_with_timeout
 
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
@@ -218,6 +218,11 @@ class InferenceSession:
         else:
             assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
 
+        inputs_device = inputs.device
+        inputs_dtype = inputs.dtype
+        inputs = inputs.cpu()
+        prompts = prompts.cpu()
+
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
             raise ValueError(
@@ -300,11 +305,14 @@ class InferenceSession:
                         f"Caught exception when running inference from block {block_idx} "
                         f"Caught exception when running inference from block {block_idx} "
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                         f"(retry in {delay:.0f} sec): {repr(e)}"
                     )
                     )
-                    logger.debug("See detailed traceback below:", exc_info=True)
+                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                     time.sleep(delay)
 
         self._position += n_input_tokens
-        return inputs
+
+        outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+        return outputs
 
     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""

+ 1 - 1
src/client/remote_model.py

@@ -129,7 +129,7 @@ class DistributedBloomModel(BloomModel):
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
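
With the `.float()` call removed, the embedding layernorm in `DistributedBloomModel` now runs in whatever dtype the word embeddings use (e.g. bf16) instead of silently upcasting to fp32. A small illustration of the dtype behavior this preserves (shapes and dtype are arbitrary, not taken from this commit):

```python
import torch
from torch import nn

ln = nn.LayerNorm(1024).to(torch.bfloat16)
inputs_embeds = torch.randn(1, 8, 1024, dtype=torch.bfloat16)
hidden_states = ln(inputs_embeds)
# Previously inputs_embeds.float() forced an fp32 upcast here; now the dtype is preserved.
assert hidden_states.dtype == torch.bfloat16
```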

+ 0 - 1
src/client/remote_sequential.py

@@ -31,7 +31,6 @@ class RemoteSequential(nn.Module):
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
-        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         super().__init__()
         self.config = config
         self.dht = dht

+ 1 - 1
src/client/sequence_manager.py

@@ -30,7 +30,7 @@ class RemoteSequenceManager:
         block_uids: Sequence[ModuleUID],
         p2p: P2P,
         max_retries: int = 3,
-        timeout: float = 5,
+        timeout: float = 20,
         min_backoff: float = 1,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"

+ 27 - 3
src/client/sequential_autograd.py

@@ -3,6 +3,7 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 """
 import asyncio
 import asyncio
 import itertools
 import itertools
+import logging
 from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
@@ -36,6 +37,11 @@ async def sequential_forward(
 
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
+    inputs_device = inputs.device
+    inputs_dtype = inputs.dtype
+    inputs = inputs.cpu()
+    prompts = prompts.cpu()
+
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
     assert is_dummy(prompts) or len(prompts) == len(
@@ -86,9 +92,12 @@ async def sequential_forward(
                     f"Caught exception when running forward from block {block_idx} "
                     f"Caught exception when running forward from block {block_idx} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
+    outputs = inputs.to(device=inputs_device, dtype=inputs_dtype)
+    intermediate_inputs = [tensor.to(device=inputs_device, dtype=inputs_dtype) for tensor in intermediate_inputs]
     return outputs, intermediate_inputs, done_sequences
 
 
@@ -98,13 +107,22 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-) -> Sequence[torch.Tensor]:
+) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
     """
     """
     Performs chained backward for each forward subsequence.
     Performs chained backward for each forward subsequence.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     If some subsequence fails, reconstructs the particular sub-path and recovers the backward.
     """
     """
     assert len(intermediate_inputs) == len(forward_sequences)
     assert len(intermediate_inputs) == len(forward_sequences)
 
 
+    grad_outputs_device = grad_outputs[0].device if grad_outputs else None
+    grad_outputs_dtype = grad_outputs[0].dtype if grad_outputs else None
+    prompts_device = prompts.device
+    prompts_dtype = prompts.dtype
+
+    grad_outputs = [tensor.cpu() for tensor in grad_outputs]
+    intermediate_inputs = [tensor.cpu() for tensor in intermediate_inputs]
+    prompts = prompts.cpu()
+
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
         inputs = intermediate_inputs.pop()
@@ -146,12 +164,18 @@ async def sequential_backward(
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"Caught exception when running backward between blocks {span.start}-{span.end} "
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                     f"(retry in {delay:.0f} sec): {repr(e)}"
                 )
                 )
-                logger.debug("See detailed traceback below:", exc_info=True)
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
                 await asyncio.sleep(delay)
 
     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension
     grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else None
+
+    if grad_outputs_dtype is not None:
+        grad_outputs = [tensor.to(device=grad_outputs_device, dtype=grad_outputs_dtype) for tensor in grad_outputs]
+    if grad_prompts is not None:
+        grad_prompts = grad_prompts.to(device=prompts_device, dtype=prompts_dtype)
     return grad_outputs, grad_prompts
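
After these changes, `sequential_backward` returns a two-tuple: the gradients w.r.t. the inputs, cast back to the device/dtype of the incoming `grad_outputs`, plus `grad_prompts`, which is `None` when every span produced dummy prompt gradients. A hedged sketch of a call site (the order of the leading arguments is assumed from the asserts and usage above; it is not spelled out in this hunk):

```python
import torch
from typing import List, Sequence, Tuple

from src.client.sequential_autograd import sequential_backward

async def backprop_through_chain(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: List[torch.Tensor],
    prompts: torch.Tensor,
    forward_sequences,
    sequence_manager,
) -> Tuple[Sequence[torch.Tensor], torch.Tensor]:
    """Hypothetical wrapper showing how a caller consumes the updated return type."""
    grad_outputs, grad_prompts = await sequential_backward(
        grad_outputs, intermediate_inputs, prompts, forward_sequences, sequence_manager
    )
    # grad_outputs come back on the device/dtype of the incoming gradients;
    # grad_prompts is None when no span returned real prompt gradients.
    return grad_outputs, grad_prompts
```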