Skip to content

Commit bbd7669

Browse files
committed
fix comments 2
Signed-off-by: Aleksandr Laptev <[email protected]>
1 parent 02a5e7a commit bbd7669

File tree

4 files changed

+1455
-4
lines changed

4 files changed

+1455
-4
lines changed

docs/source/starthere/tutorials.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ To run a tutorial:
106106
* - ASR
107107
- Multi-lingual ASR
108108
- `Multi-lingual ASR <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/Multilang_ASR.ipynb>`_
109+
* - ASR
110+
- ASR Confidence Estimation
111+
- `ASR Confidence Estimation <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_Confidence_Estimation.ipynb>`_
109112
* - NLP
110113
- Using Pretrained Language Models for Downstream Tasks
111114
- `Pretrained Language Models for Downstream Tasks <https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/nlp/01_Pretrained_Language_Models_for_Downstream_Tasks.ipynb>`_

nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,20 @@ def autocast():
147147
save_nt_curve(y_true, y_score, plot_dir, level + "_" + "nt")
148148
# AUC-YC curve
149149
yc_thresholds, yc_values = result_yc[-1]
150-
save_custom_confidence_curve(yc_thresholds, yc_values, plot_dir, level + "_" + "yc")
150+
save_custom_confidence_curve(
151+
yc_thresholds,
152+
yc_values,
153+
plot_dir,
154+
level + "_" + "yc",
155+
"Threshold",
156+
"True positive rate − False Positive Rate",
157+
)
151158
# ECE curve
152159
ece_thresholds, ece_values = results_ece[-1]
153160
ece_values /= max(ece_values)
154-
save_custom_confidence_curve(ece_thresholds, ece_values, plot_dir, level + "_" + "ece")
161+
save_custom_confidence_curve(
162+
ece_thresholds, ece_values, plot_dir, level + "_" + "ece", "Threshold", "|Accuracy − Confidence score|"
163+
)
155164

156165
return results
157166

nemo/collections/asr/parts/utils/confidence_metrics.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import math
1616
import os
1717
from pathlib import Path
18-
from typing import List, Tuple, Union
18+
from typing import List, Optional, Tuple, Union
1919

2020
import matplotlib.pyplot as plt
2121
import numpy as np
@@ -193,6 +193,8 @@ def save_confidence_hist(y_score: Union[List[float], np.ndarray], plot_dir: Unio
193193
os.makedirs(plot_dir, exist_ok=True)
194194
plt.hist(np.array(y_score), 50, range=(0, 1))
195195
plt.title(name)
196+
plt.xlabel("Confidence score")
197+
plt.ylabel("Count")
196198
plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300)
197199
plt.clf()
198200

@@ -247,12 +249,18 @@ def save_custom_confidence_curve(
247249
values: Union[List[float], np.ndarray],
248250
plot_dir: Union[str, Path],
249251
name: str = "my_awesome_curve",
252+
xlabel: Optional[str] = None,
253+
ylabel: Optional[str] = None,
250254
):
251255
assert len(thresholds) == len(values)
252256
os.makedirs(plot_dir, exist_ok=True)
253257
plt.plot(thresholds, values)
254258
plt.xlim([0, 1])
255259
plt.ylim([0, 1])
256260
plt.title(name)
261+
if xlabel is not None:
262+
plt.xlabel(xlabel)
263+
if ylabel is not None:
264+
plt.ylabel(ylabel)
257265
plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300)
258266
plt.clf()

tutorials/asr/ASR_Confidence_Estimation.ipynb

Lines changed: 1432 additions & 1 deletion
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)