Update modeling_GOT.py
Browse files- modeling_GOT.py +13 -14
modeling_GOT.py
CHANGED
@@ -575,17 +575,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
575 |
stopping_criteria=[stopping_criteria]
|
576 |
)
|
577 |
|
578 |
-
|
|
|
|
|
|
|
|
|
|
|
579 |
if render:
|
580 |
print('==============rendering===============')
|
581 |
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
|
582 |
|
583 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
584 |
-
|
585 |
-
if outputs.endswith(stop_str):
|
586 |
-
outputs = outputs[:-len(stop_str)]
|
587 |
-
outputs = outputs.strip()
|
588 |
-
|
589 |
if '**kern' in outputs:
|
590 |
import verovio
|
591 |
from cairosvg import svg2png
|
@@ -813,16 +812,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
813 |
max_new_tokens=4096,
|
814 |
stopping_criteria=[stopping_criteria]
|
815 |
)
|
816 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
817 |
if render:
|
818 |
print('==============rendering===============')
|
819 |
from .render_tools import content_mmd_to_html
|
820 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
821 |
-
|
822 |
-
if outputs.endswith(stop_str):
|
823 |
-
outputs = outputs[:-len(stop_str)]
|
824 |
-
outputs = outputs.strip()
|
825 |
-
|
826 |
html_path_2 = save_render_file
|
827 |
right_num = outputs.count('\\right')
|
828 |
left_num = outputs.count('\left')
|
|
|
575 |
stopping_criteria=[stopping_criteria]
|
576 |
)
|
577 |
|
578 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
579 |
+
|
580 |
+
if outputs.endswith(stop_str):
|
581 |
+
outputs = outputs[:-len(stop_str)]
|
582 |
+
outputs = outputs.strip()
|
583 |
+
|
584 |
if render:
|
585 |
print('==============rendering===============')
|
586 |
from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
|
587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
588 |
if '**kern' in outputs:
|
589 |
import verovio
|
590 |
from cairosvg import svg2png
|
|
|
812 |
max_new_tokens=4096,
|
813 |
stopping_criteria=[stopping_criteria]
|
814 |
)
|
815 |
+
|
816 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
817 |
+
|
818 |
+
if outputs.endswith(stop_str):
|
819 |
+
outputs = outputs[:-len(stop_str)]
|
820 |
+
outputs = outputs.strip()
|
821 |
+
|
822 |
if render:
|
823 |
print('==============rendering===============')
|
824 |
from .render_tools import content_mmd_to_html
|
|
|
|
|
|
|
|
|
|
|
|
|
825 |
html_path_2 = save_render_file
|
826 |
right_num = outputs.count('\\right')
|
827 |
left_num = outputs.count('\left')
|