|
import re
from typing import Any

import torch
|
|
|
|
|
class VisualGLMBasePostProcessor:
    """Base post processor for VisualGLM.

    Turns the token ids produced by the model back into text with the
    supplied tokenizer, without any further post-processing.
    """

    def __init__(self) -> None:
        # Stateless: nothing to configure.
        pass

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode generated tokens into text.

        Args:
            output_token: Generated token ids; passed unchanged to
                ``tokenizer.decode``.
            tokenizer: Object exposing a ``decode`` method.

        Returns:
            The result of ``tokenizer.decode(output_token)``.
        """
        return tokenizer.decode(output_token)
|
|
|
|
|
class VisualGLMVSRPostProcessor(VisualGLMBasePostProcessor):
    """VSR post processor for VisualGLM.

    Decodes the generated tokens and maps the free-form answer onto one of
    the labels ``'yes'``, ``'no'`` or ``'unknown'``.
    """

    # Match whole words only: a plain substring test would map "eyes" to
    # 'yes' and "not", "know", "cannot" or "nothing" to 'no'.
    _YES_PATTERN = re.compile(r'\byes\b')
    _NO_PATTERN = re.compile(r'\bno\b')

    def __call__(self, output_token: torch.Tensor, tokenizer: Any) -> str:
        """Decode ``output_token`` and classify it as a yes/no answer.

        Args:
            output_token: Generated token ids; passed unchanged to
                ``tokenizer.decode``.
            tokenizer: Object exposing a ``decode`` method that returns a
                string.

        Returns:
            ``'yes'`` if the decoded text contains the word "yes"
            (case-insensitive), otherwise ``'no'`` if it contains the word
            "no", otherwise ``'unknown'``. "yes" wins when both appear.
        """
        # Lowercase once instead of once per check.
        output_text = super().__call__(output_token, tokenizer).lower()
        if self._YES_PATTERN.search(output_text):
            return 'yes'
        if self._NO_PATTERN.search(output_text):
            return 'no'
        return 'unknown'
|
|