ucaslcl commited on
Commit
f89d84f
1 Parent(s): ff17e3b

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +13 -14
modeling_GOT.py CHANGED
@@ -575,17 +575,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
575
  stopping_criteria=[stopping_criteria]
576
  )
577
 
578
-
 
 
 
 
 
579
  if render:
580
  print('==============rendering===============')
581
  from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
582
 
583
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
584
-
585
- if outputs.endswith(stop_str):
586
- outputs = outputs[:-len(stop_str)]
587
- outputs = outputs.strip()
588
-
589
  if '**kern' in outputs:
590
  import verovio
591
  from cairosvg import svg2png
@@ -813,16 +812,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
813
  max_new_tokens=4096,
814
  stopping_criteria=[stopping_criteria]
815
  )
816
-
 
 
 
 
 
 
817
  if render:
818
  print('==============rendering===============')
819
  from .render_tools import content_mmd_to_html
820
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
821
-
822
- if outputs.endswith(stop_str):
823
- outputs = outputs[:-len(stop_str)]
824
- outputs = outputs.strip()
825
-
826
  html_path_2 = save_render_file
827
  right_num = outputs.count('\\right')
828
  left_num = outputs.count('\left')
 
575
  stopping_criteria=[stopping_criteria]
576
  )
577
 
578
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
579
+
580
+ if outputs.endswith(stop_str):
581
+ outputs = outputs[:-len(stop_str)]
582
+ outputs = outputs.strip()
583
+
584
  if render:
585
  print('==============rendering===============')
586
  from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
587
 
 
 
 
 
 
 
588
  if '**kern' in outputs:
589
  import verovio
590
  from cairosvg import svg2png
 
812
  max_new_tokens=4096,
813
  stopping_criteria=[stopping_criteria]
814
  )
815
+
816
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
817
+
818
+ if outputs.endswith(stop_str):
819
+ outputs = outputs[:-len(stop_str)]
820
+ outputs = outputs.strip()
821
+
822
  if render:
823
  print('==============rendering===============')
824
  from .render_tools import content_mmd_to_html
 
 
 
 
 
 
825
  html_path_2 = save_render_file
826
  right_num = outputs.count('\\right')
827
  left_num = outputs.count('\left')