Преглед изворни кода

remove torch_dtype on converted model

justheuristic пре 3 година
родитељ
комит
05fed964a3
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      tests/test_full_model.py

+ 1 - 1
tests/test_full_model.py

@@ -13,7 +13,7 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype="auto")
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer