FAPM / lavis /tasks /dialogue.py
wenkai's picture
Upload 560 files
4b532c0 verified
raw
history blame
4.05 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import json
import os
from lavis.common.dist_utils import main_process
from lavis.common.logger import MetricLogger
from lavis.common.registry import registry
from lavis.tasks.base_task import BaseTask
from lavis.datasets.data_utils import prepare_sample
import numpy as np
@registry.register_task("dialogue")
class DialogueTask(BaseTask):
    """Task registered under the name ``"dialogue"``.

    Validation collects one loss value per batch (``valid_step``) and reports
    their mean as ``agg_metrics`` (``after_evaluation``). The generation
    settings ``num_beams``, ``max_len`` and ``min_len`` are stored on the
    instance but are not read by any method of this class.
    """

    def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len
        self.evaluate = evaluate

        # When False, after_evaluation reports a constant 0.0 instead of the loss.
        self.report_metric = report_metric

    @classmethod
    def setup_task(cls, cfg):
        """Build the task from ``cfg.run_cfg``.

        ``num_beams``, ``max_len``, ``min_len`` and ``evaluate`` are read as
        required attributes of ``run_cfg``. ``report_metric`` is optional and
        defaults to True.
        """
        run_cfg = cfg.run_cfg

        return cls(
            num_beams=run_cfg.num_beams,
            max_len=run_cfg.max_len,
            min_len=run_cfg.min_len,
            evaluate=run_cfg.evaluate,
            report_metric=run_cfg.get("report_metric", True),
        )

    def valid_step(self, model, samples):
        """Return a one-element list holding the model's loss on *samples*.

        ``model(samples)`` must return a mapping whose ``"loss"`` entry
        supports ``.item()``. It is presumably a scalar torch tensor; confirm
        against the model.
        """
        loss = model(samples)["loss"].item()
        return [loss]

    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        """Aggregate the per-batch validation losses into a metrics dict.

        Args:
            val_result: Sized sequence of the values returned by
                ``valid_step``. Presumably a flat list of floats; confirm
                against ``BaseTask.evaluation``.
            split_name: Unused; accepted for interface compatibility.
            epoch: Unused; accepted for interface compatibility.

        Returns:
            ``{"agg_metrics": mean_loss}`` when ``report_metric`` is set,
            otherwise ``{"agg_metrics": 0.0}``. An empty ``val_result`` gives
            a NaN mean.
        """
        if not self.report_metric:
            return {"agg_metrics": 0.0}

        # NOTE(review): agg_metrics here is a loss, so lower is better. If the
        # runner keeps the checkpoint with the LARGEST agg_metrics as "best",
        # it will pick the worst epoch. Confirm the runner's comparison.
        if len(val_result) == 0:
            # Same value np.mean([]) would give, but without its
            # "Mean of empty slice" RuntimeWarning.
            return {"agg_metrics": float("nan")}

        return {"agg_metrics": np.mean(val_result)}

    @main_process
    def _report_metrics(self, eval_result_file, split_name):
        """Score *eval_result_file* with COCO caption metrics and log them.

        Appends ``{split_name: scores}`` as one JSON line to
        ``<output_dir>/evaluate.txt`` and returns the scores plus
        ``agg_metrics = CIDEr + Bleu_4``. No method of this class calls it.
        """
        # TODO better way to define this
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
        coco_val = coco_dialogue_eval(coco_gt_root, eval_result_file, split_name)

        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: dict(coco_val.eval)}

        log_path = os.path.join(registry.get_path("output_dir"), "evaluate.txt")
        with open(log_path, "a") as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = dict(coco_val.eval)
        coco_res["agg_metrics"] = agg_metrics

        return coco_res
# TODO better structure for this.
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from torchvision.datasets.utils import download_url
def coco_dialogue_eval(coco_gt_root, results_file, split):
    """Score a results file against the COCO Karpathy caption ground truth.

    Downloads the ground-truth annotation JSON for *split* into
    *coco_gt_root*, loads it and *results_file* with pycocotools, runs
    ``COCOEvalCap.evaluate()`` and prints every score.

    Args:
        coco_gt_root: Directory the annotation JSON is downloaded to and read
            from.
        results_file: Predictions file passed to ``COCO.loadRes``.
        split: Either ``"val"`` or ``"test"``.

    Returns:
        The ``COCOEvalCap`` instance; its ``.eval`` dict holds the scores.

    Raises:
        KeyError: If *split* is not one of the known splits. The check runs
            before any download is attempted.
    """
    urls = {
        "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json",
        "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json",
    }
    if split not in urls:
        # KeyError is kept for backward compatibility; the message now names
        # the valid choices instead of only echoing the bad key.
        raise KeyError(f"Unknown split {split!r}; expected one of {sorted(urls)}")
    url = urls[split]

    # Derive the local file name once and pass it explicitly, so the
    # downloaded path and the path opened below cannot drift apart.
    filename = os.path.basename(url)
    download_url(url, coco_gt_root, filename=filename)
    annotation_file = os.path.join(coco_gt_root, filename)

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # To evaluate only the images present in the results (a subset), uncomment:
    # coco_eval.params["image_id"] = coco_result.getImgIds()
    # Leave it commented out when evaluating the full split.

    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval