浏览代码

[minor fix] make DistributedBloomForSequenceClassification to not create the model during init

Dmitry Baranchuk 3 年之前
父节点
当前提交
98d8d48267
共有 1 个文件被更改,包括 6 次插入2 次删除
  1. 6 2
      src/client/remote_model.py

+ 6 - 2
src/client/remote_model.py

@@ -187,8 +187,12 @@ class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
-        super().__init__(config)
-        self.transformer = DistributedBloomModel(config)
+        BloomPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
 
+        self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+        
         # Initialize weights and apply final processing
         self.post_init()
+