Ver Fonte

[minor fix] make DistributedBloomForSequenceClassification to not create the model during init

Dmitry Baranchuk há 3 anos atrás
pai
commit
98d8d48267
1 ficheiros alterados com 6 adições e 2 exclusões
  1. 6 2
      src/client/remote_model.py

+ 6 - 2
src/client/remote_model.py

@@ -187,8 +187,12 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
-        super().__init__(config)
-        self.transformer = DistributedBloomModel(config)
+        BloomPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
 
+        self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+        
         # Initialize weights and apply final processing
         self.post_init()
+