glenn-jocher commited on
Commit
ec2da4a
1 Parent(s): 46e1fdf

Add ConfusionMatrix `normalize=True` flag (#3586)

Browse files
Files changed (1) hide show
  1. utils/metrics.py +4 -3
utils/metrics.py CHANGED
@@ -158,11 +158,12 @@ class ConfusionMatrix:
158
  def matrix(self):
159
  return self.matrix
160
 
161
- def plot(self, save_dir='', names=()):
162
  try:
163
  import seaborn as sn
164
-
165
- array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
 
166
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
167
 
168
  fig = plt.figure(figsize=(12, 9), tight_layout=True)
 
158
  def matrix(self):
159
  return self.matrix
160
 
161
+ def plot(self, normalize=True, save_dir='', names=()):
162
  try:
163
  import seaborn as sn
164
+
165
+ if normalize:
166
+ array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns
167
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
168
 
169
  fig = plt.figure(figsize=(12, 9), tight_layout=True)