|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
from model import SALMONN |
|
|
|
def build_arg_parser():
    """Create the command-line parser for this interactive inference script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--device``,
        ``--ckpt_path``, ``--whisper_path``, ``--beats_path``,
        ``--vicuna_path``, ``--lora_alpha``, ``--low_resource`` and
        ``--debug``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--whisper_path", type=str, default=None)
    parser.add_argument("--beats_path", type=str, default=None)
    parser.add_argument("--vicuna_path", type=str, default=None)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--low_resource", action="store_true", default=False)
    parser.add_argument("--debug", action="store_true", default=False)
    return parser


def run_interactive(model, debug=False, read=input):
    """Repeatedly ask for a wav path and a prompt and print the model output.

    Each round prints a separator, reads two lines through ``read``, calls
    ``model.generate(wav_path, prompt=prompt)`` and prints the first item of
    the returned value. An exception raised while generating is printed and
    the loop moves on to the next round.

    Args:
        model: object providing ``generate(wav_path, prompt=...)``.
        debug: when true, open ``pdb`` after a failed generation.
        read: callable with the signature of the builtin ``input``; it can be
            replaced to drive the loop from something other than stdin.

    Returns:
        None, once ``read`` raises ``EOFError`` or ``KeyboardInterrupt``.
    """
    while True:
        print("=====================================")
        try:
            wav_path = read("Your Wav Path:\n")
            prompt = read("Your Prompt:\n")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin (Ctrl-D, exhausted pipe) or Ctrl-C at a prompt is
            # how the user leaves the loop; exit quietly, not via traceback.
            print()
            return

        try:
            print("Output:")
            print(model.generate(wav_path, prompt=prompt)[0])
        except Exception as e:
            # Deliberately broad: one bad request must not end the session.
            print(e)
            if debug:
                import pdb

                pdb.set_trace()


def main(argv=None):
    """Parse the arguments, build the model and start the prompt loop.

    Args:
        argv: argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    model = SALMONN(
        ckpt=args.ckpt_path,
        whisper_path=args.whisper_path,
        beats_path=args.beats_path,
        vicuna_path=args.vicuna_path,
        lora_alpha=args.lora_alpha,
        low_resource=args.low_resource,
    )
    model.to(args.device)
    model.eval()

    run_interactive(model, debug=args.debug)


if __name__ == "__main__":
    main()
|
|