import torch device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') class CTCLabelConverter(object): """ Convert between text-label and text-index """ def __init__(self, character): # character (str): set of the possible characters. dict_character = list(character) self.dict = {} for i, char in enumerate(dict_character): # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss self.dict[char] = i + 1 self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) def encode(self, text, batch_max_length=25): """convert text-label into text-index. input: text: text labels of each image. [batch_size] batch_max_length: max length of text label in the batch. 25 by default output: text: text index for CTCLoss. [batch_size, batch_max_length] length: length of each text. [batch_size] """ length = [len(s) for s in text] # The index used for padding (=0) would not affect the CTC loss calculation. batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) for i, t in enumerate(text): text = list(t) text = [self.dict[char] for char in text] batch_text[i][:len(text)] = torch.LongTensor(text) return (batch_text.to(device), torch.IntTensor(length).to(device)) def decode(self, text_index, length): """ convert text-index into text-label. """ texts = [] for index, l in enumerate(length): t = text_index[index, :] char_list = [] for i in range(l): if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. char_list.append(self.character[t[i]]) text = ''.join(char_list) texts.append(text) return texts class CTCLabelConverterForBaiduWarpctc(object): """ Convert between text-label and text-index for baidu warpctc """ def __init__(self, character): # character (str): set of the possible characters. dict_character = list(character) self.dict = {} for i, char in enumerate(dict_character): # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss self.dict[char] = i + 1 self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) def encode(self, text, batch_max_length=25): """convert text-label into text-index. input: text: text labels of each image. [batch_size] output: text: concatenated text index for CTCLoss. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] length: length of each text. [batch_size] """ length = [len(s) for s in text] text = ''.join(text) text = [self.dict[char] for char in text] return (torch.IntTensor(text), torch.IntTensor(length)) def decode(self, text_index, length): """ convert text-index into text-label. """ texts = [] index = 0 for l in length: t = text_index[index:index + l] char_list = [] for i in range(l): if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. char_list.append(self.character[t[i]]) text = ''.join(char_list) texts.append(text) index += l return texts class AttnLabelConverter(object): """ Convert between text-label and text-index """ def __init__(self, character): # character (str): set of the possible characters. # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] list_character = list(character) self.character = list_token + list_character self.dict = {} for i, char in enumerate(self.character): # print(i, char) self.dict[char] = i def encode(self, text, batch_max_length=25): """ convert text-label into text-index. input: text: text labels of each image. [batch_size] batch_max_length: max length of text label in the batch. 25 by default output: text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] """ length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. # batch_max_length = max(length) # this is not allowed for multi-gpu setting batch_max_length += 1 # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) for i, t in enumerate(text): text = list(t) text.append('[s]') text = [self.dict[char] for char in text] batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token return (batch_text.to(device), torch.IntTensor(length).to(device)) def decode(self, text_index, length): """ convert text-index into text-label. """ texts = [] for index, l in enumerate(length): text = ''.join([self.character[i] for i in text_index[index, :]]) texts.append(text) return texts class Averager(object): """Compute average for torch.Tensor, used for loss average.""" def __init__(self): self.reset() def add(self, v): count = v.data.numel() v = v.data.sum() self.n_count += count self.sum += v def reset(self): self.n_count = 0 self.sum = 0 def val(self): res = 0 if self.n_count != 0: res = self.sum / float(self.n_count) return res