Spaces:
Running
Running
fix input format conversion
Browse files
ece.py
CHANGED
|
@@ -21,7 +21,6 @@ from torchmetrics.functional.classification.calibration_error import (
|
|
| 21 |
binary_calibration_error,
|
| 22 |
multiclass_calibration_error,
|
| 23 |
)
|
| 24 |
-
from numpy import ndarray
|
| 25 |
|
| 26 |
|
| 27 |
_CITATION = """\
|
|
@@ -109,15 +108,21 @@ class ECE(evaluate.Metric):
|
|
| 109 |
predictions = Tensor(predictions)
|
| 110 |
references = LongTensor(references)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# Compute the calibration
|
| 117 |
-
if
|
| 118 |
-
ece = multiclass_calibration_error(predictions, references, **kwargs)
|
| 119 |
-
else:
|
| 120 |
ece = binary_calibration_error(predictions, references, **kwargs)
|
|
|
|
|
|
|
| 121 |
return {
|
| 122 |
"ece": float(ece),
|
| 123 |
}
|
|
|
|
| 21 |
binary_calibration_error,
|
| 22 |
multiclass_calibration_error,
|
| 23 |
)
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
_CITATION = """\
|
|
|
|
| 108 |
predictions = Tensor(predictions)
|
| 109 |
references = LongTensor(references)
|
| 110 |
|
| 111 |
+
# Determine number of classes / binary or multiclass
|
| 112 |
+
binary = True
|
| 113 |
+
if "num_classes" not in kwargs:
|
| 114 |
+
max_label = int(amax(references, list(range(references.dim()))))
|
| 115 |
+
if max_label > 1:
|
| 116 |
+
kwargs["num_classes"] = max_label
|
| 117 |
+
binary = False
|
| 118 |
+
elif kwargs["num_classes"] > 1:
|
| 119 |
+
binary = False
|
| 120 |
|
| 121 |
# Compute the calibration
|
| 122 |
+
if binary:
|
|
|
|
|
|
|
| 123 |
ece = binary_calibration_error(predictions, references, **kwargs)
|
| 124 |
+
else:
|
| 125 |
+
ece = multiclass_calibration_error(predictions, references, **kwargs)
|
| 126 |
return {
|
| 127 |
"ece": float(ece),
|
| 128 |
}
|