Spaces:
Running
Running
| # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Dict | |
| import evaluate | |
| import datasets | |
| from torch import Tensor, LongTensor | |
| from torchmetrics.functional.classification.calibration_error import ( | |
| binary_calibration_error, | |
| multiclass_calibration_error, | |
| ) | |
# BibTeX entry for this module, passed as the `citation` of the MetricInfo.
# `author` (not `authors`) is the BibTeX field name. `@misc` is used because this
# is not a proceedings paper, and `@InProceedings` would require a `booktitle`.
_CITATION = """\
@misc{huggingface:ece,
  title = {Expected calibration error (ECE)},
  author = {Nathan Fradet},
  howpublished = {https://huggingface.co/spaces/Natooz/ece},
  year = {2023}
}
"""
# Short description of the metric, passed as the `description` of the MetricInfo.
_DESCRIPTION = """\
This metric computes the expected calibration error (ECE).
It directly calls the torchmetrics package:
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
"""
# Usage documentation, passed as the `inputs_description` of the MetricInfo.
# The example mirrors the torchmetrics `multiclass_calibration_error` example:
# probabilities go in `predictions` and integer labels in `references`.
_KWARGS_DESCRIPTION = """
Computes the expected calibration error (ECE) of predicted probabilities given reference labels.
Args:
    predictions: list of predicted probabilities to score. They must have a shape (N,C,...) if multiclass, or (N,...) if binary.
    references: list of integer labels, one per prediction, with a shape (N,...).
    **kwargs: additional arguments passed to torchmetrics, such as `num_classes`, `n_bins` or `norm`.
Returns:
    ece: expected calibration error
Examples:
    >>> import numpy as np
    >>> ece = evaluate.load("Natooz/ece")
    >>> results = ece.compute(
    ...     references=np.array([0, 1, 2, 0]),
    ...     predictions=np.array([[0.25, 0.20, 0.55],
    ...                           [0.55, 0.05, 0.40],
    ...                           [0.10, 0.30, 0.60],
    ...                           [0.90, 0.05, 0.05]]),
    ...     num_classes=3,
    ...     n_bins=3,
    ...     norm="l1",
    ... )
    >>> print(round(results["ece"], 4))
    0.2
"""
class ECE(evaluate.Metric):
    """
    Proxy to the expected calibration error (ECE) functions of the torchmetrics package:
    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html

    Binary inputs are forwarded to ``binary_calibration_error`` and multiclass inputs to
    ``multiclass_calibration_error``. See ``_compute`` for how the setting is chosen.
    """

    def _info(self):
        """Describe the module: texts shown on its page, accepted inputs and links."""
        return evaluate.MetricInfo(
            module_type="metric",
            # Texts displayed on the module's page.
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Accepted input formats, one option per setting. evaluate selects the
            # option able to encode the given inputs:
            #   - multiclass: one probability vector per sample, one integer label per sample;
            #   - binary: one probability per sample, one integer label per sample.
            # The binary option is needed for flat (N,) predictions to pass feature
            # encoding and reach the binary branch of ``_compute``.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Sequence(datasets.Value("float32")),
                        "references": datasets.Value("int64"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("float32"),
                        "references": datasets.Value("int64"),
                    }
                ),
            ],
            # Homepage of the module for documentation
            homepage="https://huggingface.co/spaces/Natooz/ece",
            # Additional links to the codebase or references
            codebase_urls=[
                "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
            ],
            reference_urls=[
                "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
            ],
        )

    def _compute(self, predictions=None, references=None, **kwargs) -> Dict[str, float]:
        """Return the expected calibration error of ``predictions`` given ``references``.

        Args:
            predictions: predicted probabilities, of shape (N, C, ...) for multiclass
                or (N, ...) for binary.
            references: integer labels, of shape (N, ...).
            **kwargs: forwarded to the torchmetrics function (e.g. ``n_bins``, ``norm``,
                ``num_classes``). See
                https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html

        The setting is deduced from the number of dimensions:
            - multiclass: predictions have one more dimension than references, with the
              classes on dimension 1;
            - binary: predictions and references have the same number of dimensions.
        In the multiclass setting, ``num_classes`` defaults to ``predictions.shape[1]``
        when it is not provided.

        Returns:
            ``{"ece": value}``, with the error as a Python float.

        Raises:
            ValueError: if ``num_classes`` is given for binary (same-rank) inputs, or if
                the numbers of dimensions match neither setting.
        """
        predictions = Tensor(predictions)
        references = LongTensor(references)

        error_msg = (
            "Expected to have predictions with shape (N,C,...) for multiclass or (N,...) for binary, "
            f"and references with shape (N,...), but got {predictions.shape} and {references.shape}"
        )
        if predictions.dim() == references.dim() + 1:
            # Multiclass: the extra dimension of the predictions holds the classes.
            binary = False
            kwargs.setdefault("num_classes", int(predictions.shape[1]))
        elif predictions.dim() == references.dim():
            binary = True
            # num_classes is only meaningful in the multiclass setting.
            if "num_classes" in kwargs:
                raise ValueError(
                    "You gave the num_classes argument, with predictions and references having the "
                    "same number of dimensions. " + error_msg
                )
        else:
            raise ValueError("Bad input shape. " + error_msg)

        if binary:
            ece = binary_calibration_error(predictions, references, **kwargs)
        else:
            ece = multiclass_calibration_error(predictions, references, **kwargs)
        return {"ece": float(ece)}