GigaCheck-Detector-Multi / configuration_gigacheck.py
iitolstykh's picture
Upload 2 files
ef31f0e verified
raw
history blame contribute delete
858 Bytes
from typing import Dict, Optional, Any
from transformers import MistralConfig
class GigaCheckConfig(MistralConfig):
    """MistralConfig extended with GigaCheck-specific settings.

    Every ``MistralConfig`` option is accepted through ``**kwargs``; the
    extra arguments below are stored as attributes of the same name.

    Args:
        with_detr: Stored as ``with_detr``.
        detr_config: Optional dict stored unchanged as ``detr_config``.
        freeze_backbone: Stored as ``freeze_backbone``.
        id2label: Optional mapping from class index to label name. Keys are
            coerced to ``int`` (JSON turns them into strings). When non-empty
            it takes precedence over ``num_labels`` and ``label2id`` is set to
            its inverse.
        num_labels: Number of labels; ignored in favour of ``len(id2label)``
            when a non-empty ``id2label`` is given.
        max_length: Stored as ``max_length``.
        conf_interval_thresh: Stored as ``conf_interval_thresh``.
        **kwargs: Forwarded to ``MistralConfig.__init__``.
    """

    def __init__(
        self,
        with_detr: bool = False,
        detr_config: Optional[Dict[str, Any]] = None,
        freeze_backbone: bool = False,
        id2label: Optional[Dict[int, str]] = None,
        num_labels: int = 2,
        max_length: int = 1024,
        conf_interval_thresh: float = 0.8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.with_detr = with_detr
        self.detr_config = detr_config
        self.freeze_backbone = freeze_backbone

        if id2label:
            # JSON object keys are always strings; normalise to ints *before*
            # the label count is applied below, not afterwards.
            id2label = {int(k): v for k, v in id2label.items()}
            # NOTE(review): when rebuilt from a saved config, num_labels appears
            # not to be passed and falls back to 2. Assigning a mismatching
            # num_labels through the PretrainedConfig setter replaces id2label
            # with generic LABEL_i names, so let the explicit mapping decide
            # the label count (same precedence PretrainedConfig uses).
            num_labels = len(id2label)
        self.id2label = id2label
        # With id2label None/empty this setter builds the default LABEL_i
        # mapping; with a matching id2label it leaves the mapping untouched.
        self.num_labels = num_labels
        if id2label:
            # Keep the reverse mapping consistent with the real label names.
            self.label2id = {name: idx for idx, name in id2label.items()}

        self.max_length = max_length
        self.conf_interval_thresh = conf_interval_thresh