File size: 2,491 Bytes
f7a83c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json

from torch.utils.data import Dataset
from tqdm import tqdm
import base64
from io import BytesIO
from PIL import Image
from torchvision import transforms
from loguru import logger


class CaptionDataset(Dataset):
    """Image-caption dataset yielding one sample per (image, caption) pair.

    Images are read from a tab-separated file of ``image_id<TAB>base64`` lines;
    captions are read from a JSON-lines file whose records carry an
    ``image_id`` and a list of captions under ``text``. Every caption becomes
    its own sample, all sharing the same base64 image string.
    """

    def __init__(self, caption_file, image_file, resolution=256):
        """Load all captions and their base64-encoded images into memory.

        Args:
            caption_file: Path to the JSON-lines caption file.
            image_file: Path to the TSV file mapping image ids to base64
                image content.
            resolution: Side length in pixels that images are resized to
                (square output). Defaults to 256.

        Raises:
            KeyError: If a caption record references an image id that is
                absent from ``image_file``.
            ValueError: If a non-blank image line does not have exactly two
                tab-separated fields, or a caption line is not valid JSON.
        """
        logger.info('loading data from:{} and {}'.format(caption_file, image_file))
        # Map every image id to its base64 content.
        image_id2content = self._load_image_contents(image_file)
        # Expand every caption of every image into a training sample.
        data_list = self._load_samples(caption_file, image_id2content)

        logger.info('len of data:{}'.format(len(data_list)))
        self.data_list = data_list

        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        # A staticmethod (not a lambda) keeps this pipeline picklable, which
        # DataLoader workers require under the 'spawn' start method.
        self.patch_resize_transform = transforms.Compose([
            self._convert_to_rgb,
            transforms.Resize(
                (resolution, resolution),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
            transforms.Normalize(mean=mean, std=std),
        ])

    @staticmethod
    def _convert_to_rgb(image):
        """Return *image* converted to 3-channel RGB."""
        return image.convert("RGB")

    @staticmethod
    def _load_image_contents(image_file):
        """Read ``image_id<TAB>base64`` lines into an ``{image_id: base64}`` dict."""
        image_id2content = {}
        with open(image_file, 'r', encoding='utf8') as f:
            # Iterate lazily rather than readlines(): the file holds whole
            # images as base64, so buffering all lines keeps a second full
            # copy of the data alive while the dict is built.
            for line in tqdm(f):
                line = line.rstrip('\r\n')
                if not line:  # tolerate blank lines, e.g. at end of file
                    continue
                image_id, image_content = line.split('\t')
                image_id2content[image_id] = image_content
        return image_id2content

    @staticmethod
    def _load_samples(caption_file, image_id2content):
        """Expand each JSON-lines caption record into one sample per caption."""
        data_list = []
        with open(caption_file, 'r', encoding='utf8') as f:
            for line in tqdm(f):
                if not line.strip():  # tolerate blank lines
                    continue
                record = json.loads(line)
                image_id = record['image_id']
                for caption in record['text']:
                    # KeyError here means the caption file references an
                    # image that the image file does not contain.
                    data_list.append({
                        'caption': caption,
                        'image_base64': image_id2content[image_id],
                        'image_id': image_id,
                    })
        return data_list

    def __len__(self):
        """Return the number of (image, caption) samples."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Return ``{'patch_image': ..., 'caption': str}`` for sample *index*.

        ``patch_image`` is a float tensor of shape
        ``(1, 3, resolution, resolution)`` (leading dim from ``unsqueeze(0)``).
        If decoding or transforming the image fails, the error is logged and
        ``patch_image`` is ``None``; callers must be prepared for that.
        """
        row = self.data_list[index]
        caption = row['caption'].strip()
        image_base64 = row['image_base64']
        image_id = row['image_id']

        # Decode and preprocess the image. This is deliberately best-effort so
        # that a single corrupt image does not abort iteration.
        try:
            image = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64)))
            patch_image = self.patch_resize_transform(image).unsqueeze(0)
        except Exception as e:
            logger.warning('open image error, image_id: {}'.format(image_id))
            logger.warning(e)
            patch_image = None

        return {'patch_image': patch_image, 'caption': caption}