glenn-jocher
commited on
Commit
•
ec2da4a
1
Parent(s):
46e1fdf
Add ConfusionMatrix `normalize=True` flag (#3586)
Browse files- utils/metrics.py +4 -3
utils/metrics.py
CHANGED
@@ -158,11 +158,12 @@ class ConfusionMatrix:
|
|
158 |
def matrix(self):
|
159 |
return self.matrix
|
160 |
|
161 |
-
def plot(self, save_dir='', names=()):
|
162 |
try:
|
163 |
import seaborn as sn
|
164 |
-
|
165 |
-
|
|
|
166 |
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
167 |
|
168 |
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|
|
|
158 |
def matrix(self):
|
159 |
return self.matrix
|
160 |
|
161 |
+
def plot(self, normalize=True, save_dir='', names=()):
|
162 |
try:
|
163 |
import seaborn as sn
|
164 |
+
|
165 |
+
if normalize:
|
166 |
+
array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns
|
167 |
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
168 |
|
169 |
fig = plt.figure(figsize=(12, 9), tight_layout=True)
|