File size: 757 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import re
from typing import Any

import torch
class VisualGLMBasePostProcessor:
    """Base post processor for VisualGLM.

    Turns the token ids produced by the model into text by delegating to the
    tokenizer's ``decode`` method. Subclasses override ``__call__`` to map the
    decoded text onto task-specific answers.
    """

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` into text.

        Args:
            output_token: Generated token ids. Presumably a 1-D tensor of ids;
                anything ``tokenizer.decode`` accepts works — confirm against
                the caller.
            tokenizer: Object exposing a ``decode(tokens) -> str`` method.

        Returns:
            The decoded text, returned unchanged.
        """
        return tokenizer.decode(output_token)
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
"""VSR post processor for VisualGLM."""
def __init__(self) -> None:
super().__init__()
def __call__(self, output_token: torch.tensor, tokenizer: Any) -> str:
output_text = tokenizer.decode(output_token)
if 'yes' in output_text.lower():
return 'yes'
elif 'no' in output_text.lower():
return 'no'
else:
return 'unknown'
|