ryanramos committed
Commit
d1b8c9b
1 Parent(s): 0df51f9

Add source code

__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
answer_list.json ADDED
@@ -0,0 +1 @@
+ ["net", "pitcher", "orange", "yes", "white", "skiing", "red", "frisbee", "brushing teeth", "no", "black and white", "skateboard", "1", "blue", "green", "motorcycle", "gray", "2", "purse", "skis", "poles", "surfboard", "dog", "on", "office", "large", "very big", "laptop", "vent", "computer", "black", "bear", "3", "wii", "glasses", "tree", "eating", "log", "5", "raft", "left", "living room", "pink", "right", "railing", "grass", "wire", "10 years", "knife", "cake", "banana", "chef", "vanilla", "4", "outdoor", "mustard", "bun", "clouds", "dock", "brown", "silver", "refrigerator", "square", "teddy", "elm", "stripes", "baseball", "catcher", "beer", "bottom", "north", "nike", "yellow and white", "morning", "elephant", "red and white", "propeller", "tan", "wall", "rolex", "clock", "table", "0", "wood", "christmas", "spinach", "thick", "bag", "leaves", "necklace", "6", "bathroom", "shower", "towel", "solid", "referee", "wilson", "8:00", "e", "24", "hat", "grazing", "sheep", "10", "tag", "spanish", "hot dog", "plate", "lunch", "butter", "peppers", "onions", "very", "mayonnaise", "mayo", "sweet potato", "pig", "sweet", "flowers", "floral", "yellow", "window", "7", "pizza", "car", "cargo", "stairs", "abstract", "rug", "baseball cap", "texting", "pole", "crosswalk", "nothing", "urban", "bus", "light", "afternoon", "boat", "cheese", "paper", "real", "sun", "birthday", "words", "inside", "shadows", "tomato", "evergreen", "100 feet", "shingles", "trees", "building", "hay", "ski pole", "patterned", "walking", "ice", "laundry", "pepsi", "good", "1:50", "purple", "13", "africa", "teddy bears", "socks", "giraffe", "soccer", "blue and yellow", "zebras", "cupcake", "broccoli", "soldier", "parking lot", "cows", "herding", "on table", "fish", "nightstand", "50", "overcast", "cross", "toaster oven", "tile", "11:55", "red and yellow", "nowhere", "hair dryer", "truck", "11", "people", "rectangle", "hot dogs", "party", "12:55", "apron", "kitchen", "cooking", "ring", "1 way", "stop", "neither", "many", "female", "brushing", "tie", "tennis racket", "knife and fork", "restaurant", "cat", "bed", "sand", "ocean", "cold", "kites", "cumulus", "standing", "male", "star", "tracks", "chocolate", "round", "fork and knife", "yankees", "pictures", "dots", "bird", "parrot", "red white and blue", "man", "metal", "fence", "snowboarding", "pine", "snow", "shorts", "swim", "wine", "brick", "no parking", "children", "beef", "phone", "english", "cell phone", "pink and yellow", "clear", "watermelon", "bedroom", "fork", "cow", "rackets", "tennis rackets", "8", "collar", "tennis", "1950s", "playing tennis", "skirt", "30", "polka dot", "beach", "horse", "grill", "african american", "down", "street", "in air", "sweater", "yellow and blue", "park", "backyard", "spectators", "parasailing", "31", "river", "55", "shadow", "winter", "chicken", "tea", "evening", "dusk", "ski resort", "helmet", "penne", "bench", "resting", "elephants", "southwest", "usa", "cars", "town", "bananas", "umbrella", "container", "woman", "on counter", "salad", "striped", "motel", "vertical", "oranges", "hot sauce", "bottle", "juice", "eyes", "ground", "backpack", "black and yellow", "forward", "jackets", "1 on right", "green and yellow", "playing baseball", "riding", "sitting", "carrot", "basket", "seagull", "ski poles", "p", "parking", "street light", "mets", "strap", "bike", "riding bike", "poodle", "shoes", "carpet", "lettuce", "food", "1 foot", "roses", "mountains", "scissors", "camera", "beige", "beard", "cutting", "baby", "tape", "watch", "never", "taking 
picture", "eggs", "syrup", "sandwich", "water skiing", "microphone", "back", "bears", "donuts", "w", "sky", "double decker", "england", "surfing", "running", "shirt", "barn", "weather vane", "white and blue", "fishing", "bridge", "los angeles", "open", "red sox", "bat", "plane", "white and green", "transportation", "sunny", "bus stop", "city", "brown and white", "bicycle", "crow", "magazines", "daisy", "14", "old", "curtains", "jumped", "snowboard", "dinosaur", "racing", "asphalt", "court", "plastic", "circle", "red and blue", "zebra", "12", "biplane", "shallow", "brazil", "logo", "2:20", "electric", "night time", "motion", "toothbrushes", "orange and white", "66", "spoon", "toyota", "tennis shoes", "46", "second", "no 1", "iphone", "friend", "apple", "carnation", "15", "tiger", "glove", "airplane", "bow", "air france", "passengers", "tv", "on building", "3:55", "victorian", "steeple", "happy", "skateboarding", "fruit", "cutting board", "cantaloupe", "kiwi", "sliced", "heart", "water", "rainy", "carrots", "giraffes", "eat", "ramp", "lab", "field", "horizontal", "birds", "home", "shrimp", "12 feet", "girl", "modern", "turtle", "dell", "boots", "sunglasses", "black and orange", "yellow and black", "gloves", "hp", "desk", "both", "sign", "on street", "2000", "cirrus", "to dry", "ceiling", "fluorescent", "up", "9", "boys", "playing soccer", "american", "passenger", "turn", "palm", "no train", "wedding", "branch", "parrots", "air force", "on tracks", "small", "tank", "dirty", "france", "honda", "2.00", "whale", "vase", "flying", "professional", "driving", "tissue", "protest", "corona", "for balance", "twin", "clothes", "t shirt", "window sill", "wild", "noon", "caution", "spring", "raining", "cane", "school", "windsurfing", "parachute", "black and red", "25", "background", "toaster", "planes", "yellow and red", "spatula", "10:10", "ivory", "train", "welcome", "highway", "off", "on track", "electricity", "italy", "dinner", "sink", "squares", "5 ft", "parked", "store", "dress", "signs", "meow", "football", "rugby", "stainless steel", "la", "dirt", "blue and white", "klm", "house", "unknown", "ford", "reading", "chair", "mountain", "alive", "water skis", "picture", "parade", "slippers", "trailer", "boating", "holding it", "shade", "cloth", "6:20", "candle", "hose", "hand", "3:25", "on sidewalk", "poster", "downhill", "68", "reflection", "summer", "pickles", "halloween", "bats", "london", "zoo", "surfer", "racket", "flickr", "cutting hair", "strawberries", "mushroom", "teddy bear", "big", "suitcase", "veggie", "pepper", "houses", "70", "toshiba", "triangle", "boxes", "photograph", "smoke", "engine", "camel", "sidewalk", "left 1", "red and green", "4:35", "on couch", "candy", "minnie mouse", "homemade", "mouse", "box", "movie", "45", "strawberry", "fridge", "full", "vegetables", "bright", "play", "remote", "pond", "savannah", "celery", "concrete", "semi", "dump", "scania", "safety", "posing", "fabric", "laying", "couch", "blueberries", "handle", "pipe", "stick", "parmesan", "steak", "chain link", "catch", "barbed wire", "mozzarella", "soda", "fire hydrant", "cat food", "pepperoni", "lot", "licking", "red and black", "clay", "tennis court", "jumping", "potatoes", "toothbrush", "kite", "not at all", "flying kite", "broken", "black and silver", "lap", "outside", "44", "delta", "greyhound", "ring finger", "talking on phone", "bad", "kettle", "35", "motorcycles", "produce", "comfort", "steering wheel", "18", "humans", "coffee", "white and brown", "fall", "bread", "cherry", "4:30", "flag", "night", 
"lamp", "cucumber", "can't see", "porcelain", "oval", "museum", "rain", "sprinkles", "20", "kids", "bracelet", "sneakers", "mask", "mickey mouse", "twins", "very high", "costume", "cabbage", "paint", "lighting", "young", "air conditioner", "wooden", "board", "someone", "beets", "16", "day time", "4 inches", "lights", "ladder", "glass", "ferris wheel", "fries", "steamed", "shepherd", "cotton", "suit", "goatee", "on his head", "print", "happy birthday", "forks", "travel", "maple", "200", "oil", "jeans", "can", "chopsticks", "on wall", "construction", "mack", "36", "chinese", "moped", "festival", "gas", "throwing", "circus", "wires", "not possible", "plates", "sugar", "in", "women's", "door", "no man", "volleyball", "serving", "ponytail", "business", "decoration", "santa", "flat", "barrel", "12:15", "candles", "atv", "free", "hair", "waffle", "ball", "stop sign", "wetsuit", "very deep", "swimsuit", "green and black", "foreground", "stands", "china airlines", "flower", "300", "lobster", "on bench", "plaster", "phones", "sailboat", "apples", "road", "recently", "cones", "cactus", "rice", "vegetarian", "donut", "ketchup", "police", "mirror", "rock", "meat", "blinds", "cell phones", "china", "rust", "7:25", "stone", "vans", "middle", "eagle", "9:30", "ping pong", "microwave", "gmc", "umbrellas", "wrist", "cuddling", "laughing", "boy", "next to toilet", "tabby", "petting", "south", "40", "name tag", "checkered", "name", "slow", "cardboard", "windows", "croissant", "plain", "cookie", "on ground", "low", "water bottle", "goggles", "turkey", "pull", "shut", "kite flying", "bowl", "smile", "in bowl", "bush", "cloudy", "top left", "skateboarder", "coca cola", "pan", "drinking", "short", "floor", "thanksgiving", "radio", "drink", "on toilet", "bike rack", "bleachers", "train tracks", "horses", "far", "top", "toilet", "in water", "private", "nature", "checkers", "commercial", "stroller", "power", "stuffed animals", "uniforms", "japan", "liquor", "faucet", "green and orange", "corn", "sub", "white and yellow", "mercedes", "in sky", "tarp", "indian", "counter", "multicolored", "polar", "go", "now", "no number", "swimming", "bridle", "cowboy", "union station", "salt and pepper", "olives", "pizza cutter", "british airways", "nighttime", "domestic", "trolley", "australia", "tiles", "pug", "wicker", "british", "us airways express", "burton", "christmas tree", "napkin", "writing", "rocks", "hello kitty", "lacoste", "gold", "fan", "skateboards", "day", "on floor", "2008", "dark", "flying kites", "rural", "olympics", "bmw", "34", "factory", "denim", "typing", "for fun", "steel", "watching tv", "chevron", "driver", "baggage claim", "grapes", "f", "angels", "roof", "handlebars", "train station", "public", "oak", "sleeping", "canada", "on runway", "air canada", "on top", "tired", "blonde", "cups", "little", "adidas", "10 feet", "white and gray", "leaf", "fisheye", "forest", "war", "octagon", "raspberry", "helmets", "united states", "29", "noodles", "van", "long", "traveling", "luggage", "airport", "single", "pitching", "dugout", "garbage", "in street", "happiness", "cigarette", "on tower", "antelope", "graffiti", "skating", "on road", "curved", "red light", "washington", "ski lift", "athletics", "brace", "squatting", "catching", "batter", "batting", "game", "towards", "33", "sliding", "makeup", "japanese", "person", "pirates", "plaid", "rose", "daytime", "keyboard", "surfboards", "hummingbird", "ollie", "11:30", "clock tower", "5:55", "san francisco", "stopping", "tags", "samsung", "computers", "cabinets", 
"talking", "cage", "asparagus", "5 years", "hanger", "adult", "rabbit", "empty", "softball", "1st", "playing", "chairs", "farm", "cross country", "dump truck", "women", "snowboarder", "tall", "monkey", "mantle", "fire", "books", "quilt", "cessna", "chandelier", "dunkin donuts", "beans", "relish", "no flag", "parking meter", "spots", "ducks", "sandals", "doughnut", "lighthouse", "yacht", "german shepherd", "in middle", "raw", "chain", "2 feet", "pedestal", "sauerkraut", "bagels", "mutt", "dog and cat", "race", "poor", "cat and dog", "station", "printer", "daisies", "front", "gravel", "rear", "grassy", "pigeons", "dogs", "in car", "life", "wii remotes", "suv", "leather", "bottom right", "peace", "facebook", "blanket", "fountain", "frisbees", "12:30", "am", "scooter", "going", "analog", "america", "pitbull", "relaxing", "paddle boarding", "white and pink", "shampoo", "alps", "ride", "side", "mane", "on desk", "on chair", "2012", "multi", "straight", "big ben", "closed", "frosted", "3 feet", "waves", "buoy", "life vest", "trash can", "medium", "boxer", "very tall", "yamaha", "sunlight", "hit ball", "dry", "coke", "gym", "orange and black", "center", "rope", "flip flops", "4th of july", "siamese", "crafts", "color", "italian", "playing frisbee", "skate park", "orange juice", "windowsill", "corgi", "thumb", "peanut butter", "pie", "toast", "no hat", "benches", "diamond", "blender", "avocado", "television", "speakers", "pony", "baseball field", "pavement", "sydney", "not there", "diamonds", "4 feet", "goalie", "soccer ball", "runway", "video game", "gaming", "casual", "green and white", "toilet brush", "working", "pickup", "girls", "remotes", "pasta", "hood", "braves", "skier", "motorola", "17", "b", "100", "diet coke", "hospital", "wagon", "milk", "ferry", "rainbow", "on bed", "toward", "1:30", "19", "security", "herself", "mercedes benz", "supreme", "thin", "platform", "gray and red", "thai", "storage", "thailand", "swan", "peach", "10:05", "dome", "chiquita", "2:00", "mountain dew", "23", "knives", "street sign", "on beach", "playing wii", "using laptop", "stickers", "yogurt", "on grass", "9:50", "9:45", "sweat", "gatorade", "umpire", "37", "transport", "desktop", "desserts", "main", "boston", "fell", "top right", "case", "asleep", "over", "9:55", "grapefruit", "breakfast", "headphones", "freight", "cup", "sweatband", "nobody", "lamps", "9:25", "scarf", "on fridge", "main st", "moving", "confused", "fresh", "kiting", "blue jay", "flats", "long time", "chihuahua", "ceramic", "mushrooms", "on plate", "human", "power lines", "hotel", "map", "earring", "boarding", "display", "warm", "napkins", "brown and black", "broom", "basketball", "papers", "holding baby", "sad", "kickstand", "60", "shoulder", "sleep", "footprints", "tunnel", "1990", "hats", "6 inches", "ham", "bacon", "church", "53", "pineapple", "at camera", "red bull", "pilot", "tattoo", "work", "polar bear", "taking off", "website", "22", "4:00", "coffee maker", "fast", "fur", "rubber", "tongs", "german", "germany", "3 inches", "toy", "3:20", "calm", "pots", "balloons", "fruits", "9:20", "drawer", "oven", "soup", "stove", "heels", "wind", "island", "blood", "leg", "theater", "tennis racquet", "21", "gothic", "2:35", "wii remote", "turning", "20 feet", "pink and black", "ears", "fun", "wreath", "to right", "child", "fly", "head", "drywall", "shorter", "pier", "feeding giraffe", "in vase", "burger", "easter", "onion", "uniform", "remote control", "guitar", "time", "verizon", "tomatoes", "ship", "tulips", "glaze", "on suitcase", "tent", 
"1:45", "market", "bnsf", "bandana", "still", "don't know", "piano", "mouth", "run", "sparrow", "throw", "lines", "vest", "1950", "jet", "sepia", "2015", "busy", "lighter", "dessert", "bending", "75", "finch", "pastries", "outdoors", "bakery", "clean", "ipod", "tablecloth", "cigarettes", "looking at phone", "in front", "food truck", "face", "swinging", "safari", "500", "volkswagen", "2010", "shape", "shelves", "riding horses", "2016", "behind bus", "towels", "lemon", "straw", "bamboo", "5 feet", "hardwood", "oregon", "schnauzer", "organic", "h", "kid", "meter", "61", "charging", "bald", "caucasian", "man on left", "stand", "27", "dining room", "sandwiches", "32", "apartment", "tower", "virgin", "out", "white and red", "2:05", "i don't know", "chains", "legs", "age", "goats", "s", "congratulations", "dresser", "camper", "half", "silverware", "decorative", "hawaiian", "petting horse", "wheel", "florida", "reds", "washington dc", "moon", "conference", "screen", "controller", "robin", "men", "protection", "roll", "harley davidson", "coal", "mustache", "smiling", "pedestrians", "88", "me", "tray", "males", "monitor", "bell", "landscape", "club", "toothpick", "seagulls", "bowtie", "lake", "steam", "surf", "baseball glove", "blinders", "woods", "stuffed", "sunbathing", "shearing", "dad", "mixer", "pot", "blending", "identification", "owl", "wine glass", "on bike", "billabong", "new york", "yarn", "tube", "tennis ball", "2:55", "ice cream", "chevrolet", "shirt and tie", "taking selfie", "blue and green", "he isn't", "cutting cake", "east", "setting", "brewers", "riding bikes", "7 eleven", "stars", "jockey", "jacket", "standing still", "book", "gray and white", "pen", "red white blue", "above", "alaska", "tongue", "feathers", "k", "camping", "pasture", "corner", "away", "ski", "texas", "fire truck", "sailboats", "jump", "walk", "spray paint", "loading", "united", "1000", "brushing his teeth", "roman numerals", "garlic", "surprise", "3rd", "first", "side of road", "dodgers", "airplanes", "unsure", "russian", "wet", "skyscraper", "5 star", "brushing her teeth", "blankets", "natural", "across street", "smartphone", "duck", "sausage", "paris", "newspaper", "pants", "spices", "pillow", "to left", "snowboards", "colgate", "on elephant", "string", "horns", "2:40", "men's", "cobblestone", "regular", "staring", "28", "barber shop", "linoleum", "grind", "cut", "x", "above sink", "above stove", "dishes", "dalmatian", "watching", "glazed", "5:25", "j", "messy", "wallet", "tuna", "toasted", "grilled", "french", "green and blue", "sunflowers", "to catch frisbee", "wool", "sprint", "no grass", "cabinet", "shell", "foil", "bottles", "bar", "king", "paper towels", "friends", "beagle", "school bus", "laptops", "snowing", "cement", "pc", "accident", "stuffed animal", "wakeboard", "balance", "in suitcase", "white and black", "nikon", "cleats", "on sink", "pool", "mom", "downtown", "asian", "heater", "bathing", "193", "against wall", "canopy", "jungle", "berries", "military", "pickle", "clams", "seafood", "in box", "boats", "tables", "lizard", "lemonade", "m", "soft", "illinois", "country", "for sale", "arm", "listening", "curly", "play tennis", "hands", "cereal", "blue and red", "robe", "around neck", "red and silver", "soap", "trains", "throwing frisbee", "smoking", "india", "headband", "not very", "westin", "serve", "bicycles", "can't tell", "to catch ball", "visibility", "ana", "reins", "rodeo", "boot", "on horse", "12:35", "riding motorcycle", "mexico", "mother", "african", "left and right", "button", "earrings", 
"blackberry", "cell", "10:00", "harness", "pillows", "vegetable", "tablet", "fern", "cats", "golden retriever", "goat", "tractor", "valentine's day", "hearts", "khaki", "man on right", "mcdonald's", "player", "arriving", "husky", "on skateboard", "vases", "coat", "beanie", "coming", "granite", "shopping cart", "it's raining", "sports", "leash", "balls", "blurry", "baseball bat", "team", "mango", "mug", "eiffel tower", "worms", "trash", "robot", "show", "terrier", "painting", "rooster", "42", "jones", "state farm", "balloon", "trunk", "coach", "t", "playing game", "fireplace", "behind clouds", "uphill", "motocross", "sony", "magazine", "kitesurfing", "catching frisbee", "catch frisbee", "bud light", "drive", "fighting", "1 on left", "very old", "hallway", "lexus", "wii controller", "9:15", "fast food", "5:45", "catholic", "muffin", "traffic light", "band", "button up", "grocery", "shelf", "2:25", "honey", "plants", "oars", "foggy", "nathan's", "cord", "yard", "48", "donut shop", "chimney", "calico", "suits", "sideways", "animals", "black and blue", "bikini", "photographer", "700", "queen", "1:00", "12:05", "horseback riding", "awake", "bunny", "12:00", "continental", "flamingo", "rye", "family", "lots", "owner", "stew", "palm tree", "cruise ship", "56", "design", "ny", "far right", "tire", "younger", "biking", "at&t", "giants", "marshmallows", "caramel", "polo", "emirates", "salon", "focus", "on motorcycle", "magnets", "mat", "ivy", "cakes", "chrome", "bob", "asia", "graduation", "cauliflower", "in snow", "c", "rough", "vacation", "air", "windy", "victoria", "4:45", "trick", "coconut", "labrador", "on left", "yellow and green", "butterfly", "fake", "on napkin", "bricks", "wine glasses", "detroit", "man's", "parsley", "art", "subway", "wave", "placemat", "hydrant", "sofa", "pigeon", "riding elephant", "all", "branches", "plant", "to eat", "zucchini", "feta", "neon", "mouse pad", "cloud", "toilet paper", "pumpkin", "rowing", "toronto", "handicap", "seeds", "fly kite", "chicago", "marble", "frame", "150", "rocky", "give way", "sauce", "it's not", "control", "high chair", "playstation", "xbox", "not likely", "roman", "land", "1:35", "lifeguard", "on pizza", "size", "bull", "dandelions", "equestrian", "goose", "8 feet", "recessed", "statue", "index", "phillies", "strike", "mirrors", "pointing", "farmer", "collie", "motorbike", "lanes", "bikes", "biker", "arrows", "gas station", "logs", "smaller", "desert", "yield", "flags", "stool", "kitten", "doll", "daffodils", "letters", "dishwasher", "first base", "nuts", "2013", "persian", "swim trunks", "deep", "o", "doubles", "toothpicks", "in field", "wristband", "wheels", "baking", "4:15", "11:00", "ear", "2007", "51", "chevy", "using computer", "frog", "storm", "boogie board", "hungry", "by window", "ambulance", "pigtails", "audi", "microsoft", "on man", "cannot tell", "stained glass", "hugging", "laying down", "3:00", "taxi", "pedestrian", "landing", "numbers", "38", "stones", "on tree", "clocks", "new", "picnic", "fog", "buffalo", "under armour", "cocker spaniel", "orioles", "no sign", "telling time", "bags", "golden gate", "cover", "castle", "canoe", "selfie", "cream", "floating", "indoor", "antique", "aluminum", "silver and black", "cast iron", "peas", "sun hat", "on right", "swiss", "flour", "under sink", "fashion", "fedora", "shells", "1 hour", "puppy", "in stands", "not here", "motor", "thousands", "120", "sail", "butt", "mexican", "dead end", "paddle", "bathing suit", "shop", "onion rings", "boxing", "birthday cake", "chalk", "scenery", 
"style", "nissan", "sticker", "on rack", "1 4", "woman's", "surprised", "north face", "squash", "not sure", "email", "spotted", "seat", "himself", "circles", "san diego", "kia", "mattress", "obama", "lamb", "american flag", "climbing", "skull and crossbones", "roast beef", "visor", "herd", "double", "52", "high", "stagecoach", "cart", "feeding", "eaten", "cone", "11:15", "smoothie", "golf", "colorado", "electronics", "5:15", "bowling", "players", "ketchup and mustard", "styrofoam", "6 feet", "hawk", "cheddar", "12:28", "arabic", "12:25", "12:10", "shower curtain", "army", "salmon", "10:40", "hanging", "whole", "behind fence", "bars", "moss", "no dog", "traffic", "10:25", "r", "countryside", "machine", "directions", "cooked", "aa", "6:45", "4 way", "stripe", "brand", "baseball player", "bunk", "coleslaw", "fishing boat", "at table", "europe", "dead", "arch", "scrambled", "clothing", "closet", "egg", "suitcases", "indoors", "coffee pot", "tires", "lilies", "cafe", "9:35", "teal", "toothpaste", "in background", "tarmac", "painted", "sunset", "orange and yellow", "oar", "peaches", "zebra and giraffe", "ladybug", "20 ft", "sesame seeds", "hills", "2:30", "stucco", "tail", "couple", "kawasaki", "smooth", "powdered sugar", "pedestrian crossing", "french fries", "picnic table", "teeth", "ribbon", "saddle", "15 feet", "earbuds", "on train", "39", "curb", "tow", "shark", "white and orange", "6:25", "gravy", "fork and spoon", "pooping", "curtain", "lime", "skull", "crossing", "speed limit", "peacock", "boredom", "neck", "hit", "dragon", "tissues", "basil", "waving", "blue team", "rectangles", "helicopter", "mud", "us", "balcony", "red and gray", "firefighter", "sunflower", "wallpaper", "best buy", "11:20", "public market center", "seattle", "bookshelf", "looking", "1 inch", "harley", "urinal", "cartoon", "t shirt and jeans", "navy", "fedex", "rays", "deck", "coaster", "1:20", "50 feet", "4:20", "us open", "looking at camera", "600", "national express", "white house", "5:00", "jp morgan", "palm trees", "tub", "pens", "soldiers", "2 people", "animal", "speaker", "hamburger", "spaghetti", "green beans", "it isn't", "10:20", "buildings", "on shelf", "baseball uniform", "tiled", "orange and blue", "90", "north america", "arrow", "news", "tropicana", "formal", "in grass", "thumbs up", "clip", "gate", "tennis player", "lilac", "pastry", "nose", "pacifier", "11:35", "different teams", "cardinals", "exhaust", "hauling", "on tray", "bagel", "huge", "out of focus", "cook", "wheat", "photo", "ghost", "sedan", "qatar", "zig zag", "lanyard", "pink and white", "sesame", "space", "no clock", "warning", "snowy", "tater tots", "tropical", "grandfather", "mac", "magnet", "photoshop", "pajamas", "350", "casserole", "4:55", "pelican", "2009", "clydesdale", "tow truck", "belt", "west", "omelet", "heavy", "crown", "in corner", "hexagon", "mound", "iris", "g", "12:45", "2:15", "3:10", "drawing", "only", "little girl", "washing", "nokia", "windsor", "2 men", "parmesan cheese", "on woman", "freezer", "icing", "venice", "dairy", "several", "concentration", "3:15", "no smoking", "kayak", "frosting", "jetblue", "thoroughbred", "parakeet", "shoe", "skeleton", "britain", "ties", "in sink", "patio", "bank", "camouflage", "privacy", "bib", "blue and gray", "looking out window", "falling", "bucket", "cupcakes", "throw ball", "garden", "almonds", "ducati", "ireland", "plastic wrap", "starbucks", "all way", "bark", "home plate", "base", "dog food", "toys", "blue and orange", "1 in front", "foot", "dc", "california", "towing", 
"cheesecake", "bushes", "bow tie", "millions", "down street", "2011", "police officer", "windmill", "taking pictures", "street name", "cleaning", "on pole", "russia", "main street", "catch ball", "mario", "pirate", "track", "garage", "7:10", "they aren't", "mother and child", "tents", "fancy", "tattoos", "alcohol", "2:45", "wheelchair", "money", "top hat", "willow", "cd", "brushing hair", "pancake", "80", "listening to music", "green and red", "barrier", "vests", "hiking", "tank top", "lufthansa", "student", "menu", "forehand", "wii controllers", "acer", "wall st", "hundreds", "water ski", "furniture", "paisley", "pizza hut", "baseball game", "hill", "prom", "1 world", "tiara", "students", "information", "hazy", "nasa", "canon", "bird feeder", "crane", "dr pepper", "logitech", "2:10", "all of them", "utensils", "telephone", "converse", "bone", "jeep", "nursing", "krispy kreme", "cameraman", "pee", "ranch", "polka dots", "railroad crossing", "shirts", "feeder", "above toilet", "unclear", "below", "43", "spoons", "calendar", "vaio", "fox", "mint", "after", "spiderman", "lg", "concert", "on rock", "fluffy", "gray and black", "coats", "lady", "dodge", "easyjet", "pearl", "bunt", "flat screen", "10:30", "music", "polar bears", "riding horse", "lift", "angry", "cookies", "3:45", "buttons", "hot", "cute", "behind", "dole", "in motion", "26", "pans", "love", "winnie pooh", "pear", "copyright", "2 hours", "snowsuit", "kissing", "backhand", "to get to other side", "metro", "swans", "very fast", "can't see it", "nintendo", "direction", "waiting", "mohawk", "st patrick's day", "rail", "hoodie", "feet", "swirls", "muffins", "4:05", "106", "10:55", "coins", "mitt", "game controller", "room", "adults", "urinals", "cameras", "marker", "upright", "brass", "sled", "teacher", "conductor", "farmers market", "toiletries", "blue and black", "soccer field", "banana peel", "sprite", "doughnuts", "bank of america", "on his face", "heat", "emergency", "ski slope", "hard", "41", "6:00", "in his hand", "cluttered", "dog show", "on boat", "grizzly", "drums", "not", "in hand", "easy", "400", "under table", "d", "hitting ball", "photography", "intersection", "backwards", "crocs", "marina", "chips", "bible", "harry potter", "hawaii", "fanta", "half full", "carriage", "curious", "12:50", "black white", "geese", "pork", "mailbox", "l", "sidecar", "poop", "wings", "penguin", "to see", "pocket", "steps", "cubs", "junk", "deer", "ottoman", "salt", "condiments", "1:55", "post", "bulldog", "notebook", "no cat", "champagne", "jets", "knee pads", "throw frisbee", "drinks", "leopard", "taller", "cooler", "bundt", "monday", "grape", "wine tasting", "under", "baskets", "santa hat", "chest", "sewing", "on car", "sony ericsson", "peeing", "for photo", "tour", "few", "singapore", "fireman", "fire extinguisher", "wildebeest", "lemons", "peanuts", "babies", "wiimote", "guitar hero", "slide", "stopped", "library", "multi colored", "blue and pink", "choppy", "sailing", "brush", "grinding", "jelly", "dairy queen", "shaking hands", "ge", "tigers", "tokyo", "philadelphia", "ski boots", "buses", "11:45", "collage", "pink and blue", "jesus", "singles", "iron", "coffee table", "2 years", "don't walk", "classroom", "on water", "potato salad", "posts", "harbor", "residential", "joshua", "uk", "burgers", "deli", "kicking", "lace", "overalls", "vehicles", "ram", "dancing", "47", "shed", "lid", "he's not", "fans", "amtrak", "space shuttle", "ostrich", "bathtub", "kneeling", "2:50", "mall", "yellow and orange", "gazebo", "wax", "slow down", "lays", 
"hammer time", "octopus", "crib", "banana split", "broadway", "pottery", "wavy", "farmers", "holding phone", "on phone", "squirrel", "wax paper", "tusks", "dining", "packing", "kangaroo", "dawn", "defense", "powdered", "thomas", "budweiser", "back left", "stir fry", "beijing", "11:10", "tripod", "wide", "slope", "black and gray", "planter", "chili", "siblings", "kayaking", "captivity", "opaque", "rack", "panda", "doorway", "wheelie", "pelicans", "genetics", "not in service", "volvo", "dachshund", "v", "on laptop", "western", "gone", "birthday party", "parking garage", "tying tie", "blueberry", "scale", "notes", "train car", "man made", "stability", "lily", "lying down", "pacific", "high heels", "pare", "checkerboard", "partly cloudy", "cool", "n", "toilets", "tree branch", "copper", "cycling", "5:50", "870", "shopping", "7:05", "zipper", "holding umbrella", "batman", "lotion", "1:25", "black and brown", "playing video game", "girl on right", "legos", "drinking water", "burrito", "plow", "jet ski", "spiral", "ibm", "tools", "flashlight", "cherries", "maple leaf", "mountainous", "under tree", "vines", "sushi", "baker", "snake", "globe", "target", "john", "pomeranian", "tuxedo", "hockey", "sleeve", "leaning", "wireless", "11:05", "compaq", "do not enter", "radish", "1:05", "dim", "advertisement", "movement", "model", "hammock", "swing", "sheet", "google", "boardwalk", "right 1", "haircut", "ankle", "3:30", "exit", "csx", "tim hortons", "lego", "cucumbers", "angel", "12:20", "racquet", "behind woman", "potato", "egg salad", "controllers", "recliner", "upside down", "mosaic", "before", "antenna", "3:50", "10:15", "lion", "camo", "fighter", "silver and red", "dirt bike", "playing video games", "used", "crates", "horizontally", "plunger", "refrigerators", "radiator", "stork", "in basket", "cap", "living", "married", "briefcase", "bottom left", "30 mph", "ascending", "flip phone", "101", "11:50", "gun", "arizona", "foam", "serious", "y", "close up", "pancakes", "heineken", "paw", "cnn", "comforter", "sheets", "8:35", "driveway", "fair", "cleaner", "1 year", "delivery", "commuter", "apple and banana", "chase", "72", "safe", "trucks", "trunks", "spider", "64", "slacks", "meeting", "7:00", "skiers", "shaved", "carrot cake", "holding", "surfers", "giraffe and zebra", "7:45", "mississippi", "seaweed", "black and pink", "horse racing", "orchid", "rv", "tourist", "above door", "leaving", "pitch", "crest", "miami", "asics", "flood", "bus station", "take off", "amazon", "practice", "entering", "diesel", "pm", "wetsuits", "remodeling", "porch", "7:35", "tie dye", "baked", "life jacket", "cylinder", "grilled cheese", "meatballs", "paddling", "banana bread", "monster", "smiley face", "not high", "keys", "dreadlocks", "kitchenaid", "straight ahead", "badminton", "long sleeve", "sheepdog", "5:18", "end", "on shore", "scratching", "oriental", "5:05", "alligator", "city bus", "purple and white", "10:50", "each other", "weeds", "tinkerbell", "rottweiler", "apartments", "snowflakes", "stop light", "sweatshirt", "shore", "bidet", "switzerland", "stretching", "tv stand", "boundaries", "65", "bronze", "jar", "middle 1", "54", "skate", "easton", "turn right", "raspberries", "singing", "on bus", "carnations", "descending", "classic", "suspenders", "not long", "8:50", "father", "anniversary", "hsbc", "very long", "space needle", "skatepark", "fruit salad", "kenmore", "no water", "8:05", "db", "baby's breath", "shelter", "1980", "no left turn", "washington monument", "ham and cheese", "10 inches", "8:55", "savory", 
"6:35", "indians", "9:05", "fires", "pipes", "donkey", "cds", "mitsubishi", "tell time", "outfield", "christian", "puma", "parking meters", "cranes", "flip", "wine bottle", "stadium", "mouthwash", "heinz", "distance", "macaroni", "on plane", "triumph", "more", "4:50", "single engine", "disney", "on stove", "shih tzu", "fried", "to hit ball", "in her hand", "sunrise", "2nd", "elmo", "kite string", "suzuki", "traffic lights", "blt", "i", "hitting", "htc", "healthy", "current", "star alliance", "stomach", "watch tv", "tulip", "5:10", "right side", "4:40", "ginger", "on sign", "cushion", "5:30", "learning", "pencil", "maroon", "food processor", "5:40", "dog bed", "michigan", "close", "license plate", "crows", "right hand", "normal", "green and brown", "1.00", "000", "1:40", "wing", "american airlines", "kodak", "mural", "sniffing", "1:15", "behind bench", "cardinal", "no light", "warmth", "paved", "skyscrapers", "swinging bat", "watermark", "in cup", "pizza box", "dough", "hiding", "goal", "no plate", "shower head", "ripe", "1:10", "1 in back", "older", "nest", "multiple", "cinnamon", "bin", "new orleans", "colored", "enclosure", "bride", "on dresser", "star wars", "in back", "triangles", "over easy", "cilantro", "statues", "sticks", "formica", "roundabout", "bowls", "ahead", "years", "drain", "veggies", "no shirt", "taking photo", "tugboat", "broke", "59", "cadillac", "prince", "left side", "1 in middle", "10:45", "drying", "11:25", "silk", "conference room", "buoys", "pockets", "daffodil", "6:40", "walgreens", "4 ft", "6:05", "virgin atlantic", "12:40", "digital", "ups", "westjet", "bikers", "us air force", "limes", "comcast", "dip", "7:55", "man in middle", "bus driver", "soon", "futon", "selling", "braid", "mariners", "wisconsin", "99", "citizen", "broccoli and carrots", "grocery store", "us airways", "49", "bored", "red velvet", "hotel room", "qantas", "tam", "korean air", "10:35", "whirlpool", "coffee cup", "hilly", "9:12", "whipped cream", "video", "finger", "competition", "hollywood", "sas", "backward", "beads", "cosmo", "10:08", "jal", "6:30", "100 year party ct", "hispanic", "in cabbage town", "opponent", "woodpecker", "visilab", "mt airy", "crosstown", "freightliner"]
configs/retrieval.yaml ADDED
@@ -0,0 +1,73 @@
+ hidden_size: &hidden_size 768
+ vocab_size: &vocab_size 30522
+ type_vocab_size: &type_vocab_size 2
+ max_position_embeddings: &max_position_embeddings 512
+ pad_token_id: &pad_token_id 0
+ embed_size: &embed_size 256
+
+ seed: 42
+ world_size: 1
+ device: "cuda"
+ dist_url: "env://"
+ output_path: "./examples/albef/outputs/retrieval_output.pt"
+
+ datamodule_args:
+   train_files: ["./examples/albef/data_files/coco_train.json"]
+   test_files: ["./examples/albef/data_files/coco_test.json"]
+   image_root: "./examples/albef/data_files/coco"
+   batch_size: 32
+   num_workers: 8
+
+ vision_encoder_args:
+   hidden_size: *hidden_size
+   image_size: 384
+   patch_size: 16
+   num_hidden_layers: 12
+   num_attention_heads: 12
+   mlp_dim: 3072
+   dropout: 0.0
+   attention_dropout: 0.0
+   layer_norm_eps: 1e-6
+
+ text_encoder_args:
+   vocab_size: *vocab_size
+   hidden_size: *hidden_size
+   type_vocab_size: *type_vocab_size
+   max_position_embeddings: *max_position_embeddings
+   pad_token_id: *pad_token_id
+   num_hidden_layers: 6
+   num_attention_heads: 12
+   intermediate_size: 3072
+   layer_norm_eps: 1e-12
+   dropout: 0.0
+
+ multimodal_encoder_args:
+   hidden_size: *hidden_size
+   num_hidden_layers: 6
+   num_attention_heads: 12
+   intermediate_size: 3072
+   layer_norm_eps: 1e-12
+
+ projection_args:
+   in_features: *hidden_size
+   out_features: *embed_size
+
+ similarity_args:
+   embed_size: *embed_size
+   queue_size: 65536
+   temp: 0.07
+
+ training_args:
+   log_every_n_steps: 100
+   alpha: 0.4
+   weight_decay: 0.02
+   lr: 1e-5
+   min_lr: 1e-6
+   max_epochs: 5
+   step_size: 100
+   warmup_steps: 1
+   checkpoint_root: "./examples/albef/checkpoints"
+
+ eval_args:
+   log_every_n_steps: 100
+   k_test: 256
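
The &name anchors and *name aliases above let the shared dimensions be defined once and reused across the encoder sections. As a minimal sketch (not part of this commit; the config path is an assumption about the repo layout), loading the file with PyYAML resolves the aliases automatically:

import yaml  # PyYAML resolves YAML anchors/aliases on load

with open("examples/albef/configs/retrieval.yaml") as f:  # hypothetical path
    config = yaml.safe_load(f)

# every *hidden_size alias resolves to the single 768 defined at the top
assert config["vision_encoder_args"]["hidden_size"] == config["hidden_size"] == 768
print(config["similarity_args"])  # {'embed_size': 256, 'queue_size': 65536, 'temp': 0.07}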
configs/vqa.yaml ADDED
@@ -0,0 +1,78 @@
+ hidden_size: &hidden_size 768
+ vocab_size: &vocab_size 30522
+ type_vocab_size: &type_vocab_size 2
+ max_position_embeddings: &max_position_embeddings 512
+ pad_token_id: &pad_token_id 0
+
+ seed: 42
+ world_size: 1
+ device: "cuda"
+ dist_url: "env://"
+ output_root: "./examples/albef/outputs"
+
+ datamodule_args:
+   train_files: ["./examples/albef/data_files/vqa_train.json", "./examples/albef/data_files/vg_qa.json", "./examples/albef/data_files/vqa_val.json"]
+   test_files: ["./examples/albef/data_files/vqa_test.json"]
+   answer_list: "./examples/albef/data_files/answer_list.json"
+   vqa_root: "./examples/albef/data_files/coco"
+   vg_root: "./examples/albef/data_files/visual_genome"
+   batch_size: 32
+   num_workers: 8
+
+ vision_encoder_args:
+   hidden_size: *hidden_size
+   image_size: 384
+   patch_size: 16
+   num_hidden_layers: 12
+   num_attention_heads: 12
+   mlp_dim: 3072
+   dropout: 0.0
+   attention_dropout: 0.0
+   layer_norm_eps: 1e-6
+
+ text_encoder_args:
+   vocab_size: *vocab_size
+   hidden_size: *hidden_size
+   type_vocab_size: *type_vocab_size
+   max_position_embeddings: *max_position_embeddings
+   pad_token_id: *pad_token_id
+   num_hidden_layers: 6
+   num_attention_heads: 12
+   intermediate_size: 3072
+   layer_norm_eps: 1e-12
+   dropout: 0.0
+
+ multimodal_encoder_args:
+   hidden_size: *hidden_size
+   num_hidden_layers: 6
+   num_attention_heads: 12
+   intermediate_size: 3072
+   layer_norm_eps: 1e-12
+
+ text_embeddings_args:
+   hidden_size: *hidden_size
+   vocab_size: *vocab_size
+   pad_token_id: *pad_token_id
+   max_position_embeddings: *max_position_embeddings
+   type_vocab_size: *type_vocab_size
+   layer_norm_eps: 1e-12
+
+ prediction_head_args:
+   hidden_size: *hidden_size
+   vocab_size: *vocab_size
+   layer_norm_eps: 1e-12
+
+ training_args:
+   log_every_n_steps: 100
+   alpha: 0.4
+   weight_decay: 0.02
+   lr: 2e-5
+   min_lr: 1e-6
+   max_epochs: 8
+   step_size: 100
+   warmup_steps: 4
+   checkpoint_root: "./examples/albef/checkpoints"
+
+ eval_args:
+   log_every_n_steps: 100
+   k_test: 128
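
The datamodule_args block mirrors the constructor of the VQADataModule added later in this commit, so the section can be passed straight through as keyword arguments. A rough usage sketch (illustrative only; the training script that would do this is not part of this commit):

import yaml
from data.vqa_datamodules import VQADataModule

with open("examples/albef/configs/vqa.yaml") as f:  # hypothetical path
    config = yaml.safe_load(f)

# train_files, test_files, answer_list, vqa_root, vg_root, batch_size and
# num_workers match VQADataModule.__init__ one-to-one
datamodule = VQADataModule(**config["datamodule_args"])
train_loader = datamodule.train_dataloader()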
data/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
data/retrieval_datamodule.py ADDED
@@ -0,0 +1,188 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import List, Optional, Tuple
+
+ import torch
+ from data.retrieval_dataset import (
+     ImageToTextRetrievalDataset,
+     RetrievalTrainingDataset,
+     TextToImageRetrievalDataset,
+ )
+ from data.transforms import (
+     ALBEFTextTransform,
+     testing_image_transform,
+     training_image_transform,
+ )
+ from pytorch_lightning import LightningDataModule
+ from torch import Tensor
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
+
+
+ class RetrievalDataModule(LightningDataModule):
+     """
+     The Data Module for the Retrieval task.
+
+     Args:
+         train_files (List[str]): The paths to training json files.
+         test_files (List[str]): The paths to testing json files.
+         image_root (str): The path to the image data directory.
+         batch_size (int): The sampling batch size.
+         num_workers (int): The number of worker processes for data loading.
+     """
+
+     def __init__(
+         self,
+         train_files: List[str],
+         test_files: List[str],
+         image_root: str,
+         batch_size: int,
+         num_workers: int,
+     ) -> None:
+         super().__init__()
+         self.train_dataset = RetrievalTrainingDataset(
+             train_files,
+             image_root,
+             training_image_transform(),
+             ALBEFTextTransform(truncate=True, max_seq_len=30, add_end_token=False),
+         )
+
+         self.image_dataset = ImageToTextRetrievalDataset(
+             test_files,
+             image_root,
+             testing_image_transform(),
+         )
+
+         self.text_dataset = TextToImageRetrievalDataset(
+             test_files,
+             ALBEFTextTransform(
+                 truncate=True,
+                 pad_to_max_seq_len=True,
+                 max_seq_len=30,
+                 add_end_token=False,
+             ),
+         )
+
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+
+     def _get_sampler(
+         self,
+         dataset: Dataset,
+         shuffle: bool,
+         is_distributed: bool,
+         num_tasks: int,
+         global_rank: int,
+     ) -> Optional[DistributedSampler]:
+         # do not return a sampler if not in distributed mode;
+         # a default RandomSampler is used in this case
+         if not is_distributed:
+             return None
+
+         return DistributedSampler(
+             dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
+         )
+
+     def train_dataloader(
+         self,
+         is_distributed: bool = False,
+         num_tasks: int = 0,
+         global_rank: int = 0,
+         drop_last: bool = True,
+     ) -> DataLoader:
+         """
+         DataLoader Outputs:
+             images (Tensor): Tensor of shape (B, C, H, W) of image inputs.
+             text (Tensor): Tensor of shape (B, L) of text inputs.
+             text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
+             idx (Tensor): Tensor of shape (B) of image identifiers.
+         """
+         sampler = self._get_sampler(
+             dataset=self.train_dataset,
+             shuffle=True,
+             is_distributed=is_distributed,
+             num_tasks=num_tasks,
+             global_rank=global_rank,
+         )
+         shuffle = sampler is None
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             pin_memory=True,
+             sampler=sampler,
+             shuffle=shuffle,
+             collate_fn=retrieval_train_collate_fn,
+             drop_last=drop_last,
+         )
+
+     def image_dataloader(
+         self,
+         drop_last: bool = False,
+     ) -> DataLoader:
+         """
+         DataLoader Outputs:
+             images (Tensor): Tensor of shape (B, C, H, W) of image inputs.
+         """
+         return DataLoader(
+             self.image_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             pin_memory=True,
+             sampler=None,
+             shuffle=False,
+             collate_fn=None,
+             drop_last=drop_last,
+         )
+
+     def text_dataloader(
+         self,
+         drop_last: bool = False,
+     ) -> DataLoader:
+         """
+         DataLoader Outputs:
+             text (Tensor): Tensor of shape (B, L) of text inputs.
+             text_atts (Tensor): Tensor of shape (B, L) of text attention mask.
+         """
+         return DataLoader(
+             self.text_dataset,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             pin_memory=True,
+             sampler=None,
+             shuffle=False,
+             collate_fn=text_collate_fn,
+             drop_last=drop_last,
+         )
+
+
+ def retrieval_train_collate_fn(
+     batch: List[Tuple[Tensor, Tensor, int]]
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+     image_list = []
+     text_list = []
+     idx_list = []
+     for image, text, idx in batch:
+         image_list.append(image)
+         text_list.append(text)
+         idx_list.append(idx)
+     images = torch.stack(image_list, dim=0)
+     text = pad_sequence(text_list, batch_first=True)
+     text_atts = (text != 0).type(torch.long)
+     idx = Tensor(idx_list).type(torch.long)
+     return (
+         images,
+         text,
+         text_atts,
+         idx,
+     )
+
+
+ def text_collate_fn(batch: List[Tensor]) -> Tuple[Tensor, Tensor]:
+     text = pad_sequence(batch, batch_first=True)
+     text_atts = (text != 0).type(torch.long)
+     return text, text_atts
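
Both collate functions above pad the variable-length token sequences in a batch and build the attention mask from the pad token id 0. A small self-contained sketch of that behavior (the token ids are made up for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

# two captions of different lengths, already tokenized
captions = [torch.tensor([101, 2023, 2003, 102]), torch.tensor([101, 2062, 102])]
text = pad_sequence(captions, batch_first=True)  # pads the shorter caption with 0
text_atts = (text != 0).type(torch.long)         # 1 for real tokens, 0 for padding
print(text)       # tensor([[101, 2023, 2003, 102], [101, 2062, 102, 0]])
print(text_atts)  # tensor([[1, 1, 1, 1], [1, 1, 1, 0]])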
data/retrieval_dataset.py ADDED
@@ -0,0 +1,149 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import json
+ import os
+ from typing import Callable, List, Tuple, Union
+
+ from PIL import Image
+ from torch import Tensor
+ from torch.utils.data import Dataset
+
+
+ class RetrievalTrainingDataset(Dataset):
+     """
+     Create the training dataset for the Retrieval task.
+
+     Args:
+         ann_file (List[str]): The paths to training annotation json files.
+         image_root (str): The path to the image data directory.
+         image_transform (Callable[[Image.Image], Tensor]): Image data transform.
+         text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
+
+     Dataset Outputs:
+         image (Tensor): Transformed image input tensor of shape (C, H, W).
+         caption (Tensor): Transformed text token input ids.
+         idx (int): The unique identifier for the image.
+     """
+
+     def __init__(
+         self,
+         ann_file: List[str],
+         image_root: str,
+         image_transform: Callable[[Image.Image], Tensor],
+         text_transform: Callable[[Union[List[str], str]], Tensor],
+     ) -> None:
+         self.ann = []
+         for f in ann_file:
+             self.ann += json.load(open(f, "r"))
+
+         self.image_root = image_root
+         self.image_transform = image_transform
+         self.text_transform = text_transform
+
+         self.idx = {}  # map str image_id from dataset to int ids
+         i = 0
+         for ann in self.ann:
+             image_id = ann["image_id"]
+             if image_id not in self.idx.keys():
+                 self.idx[image_id] = i
+                 i += 1
+
+     def __len__(self) -> int:
+         return len(self.ann)
+
+     def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, int]:
+         ann = self.ann[index]
+         image_path = os.path.join(self.image_root, ann["image"])
+         image = Image.open(image_path).convert("RGB")
+         image = self.image_transform(image)
+         caption = self.text_transform(ann["caption"])
+         return image, caption, self.idx[ann["image_id"]]
+
+
+ class ImageToTextRetrievalDataset(Dataset):
+     """
+     Create the dataset for the Image-to-Text Retrieval task.
+
+     Args:
+         ann_file (List[str]): The paths to annotation json files.
+         image_root (str): The path to the image data directory.
+         image_transform (Callable[[Image.Image], Tensor]): Image data transform.
+
+     Dataset Outputs:
+         image (Tensor): Transformed image input tensor of shape (C, H, W).
+     """
+
+     def __init__(
+         self,
+         ann_file: List[str],
+         image_root: str,
+         image_transform: Callable[[Image.Image], Tensor],
+     ) -> None:
+         self.image_root = image_root
+         self.image_transform = image_transform
+
+         self.ann = []
+         self.images = []  # paths to all images in the dataset
+         self.image_to_text = {}  # map image ids to text ids for evaluation
+         for f in ann_file:
+             self.ann += json.load(open(f, "r"))
+
+         text_id = 0
+         for image_id, ann in enumerate(self.ann):
+             self.images.append(ann["image"])
+             num_text = len(ann["caption"])
+             self.image_to_text[image_id] = list(range(text_id, text_id + num_text))
+             text_id += num_text
+
+     def __len__(self) -> int:
+         return len(self.images)
+
+     def __getitem__(self, index: int) -> Tensor:
+         image_path = os.path.join(self.image_root, self.images[index])
+         image = Image.open(image_path).convert("RGB")
+         image = self.image_transform(image)
+         return image
+
+
+ class TextToImageRetrievalDataset(Dataset):
+     """
+     Create the dataset for the Text-to-Image Retrieval task.
+
+     Args:
+         ann_file (List[str]): The paths to annotation json files.
+         text_transform (Callable[[Union[List[str], str]], Tensor]): Text data transform.
+
+     Dataset Outputs:
+         text (Tensor): Transformed text token input ids.
+     """
+
+     def __init__(
+         self,
+         ann_file: List[str],
+         text_transform: Callable[[Union[List[str], str]], Tensor],
+     ) -> None:
+         self.text_transform = text_transform
+
+         self.ann = []
+         self.text = []  # all text strings in the dataset
+         self.text_to_image = {}  # map text ids to image ids for evaluation
+         for f in ann_file:
+             self.ann += json.load(open(f, "r"))
+
+         text_id = 0
+         for image_id, ann in enumerate(self.ann):
+             for caption in ann["caption"]:
+                 self.text.append(caption)
+                 self.text_to_image[text_id] = image_id
+                 text_id += 1
+
+     def __len__(self) -> int:
+         return len(self.text)
+
+     def __getitem__(self, index: int) -> Tensor:
+         text = self.text_transform(self.text[index])
+         return text
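
These datasets rely on only a few keys per annotation entry: the training entries carry a single caption string and a string image_id, while the test entries carry a list of captions per image. A hypothetical entry of each kind, consistent with how the code above indexes the annotations (file names and ids are made up; the real json files come from the COCO retrieval annotations and are not part of this commit):

# illustrative annotation entries, not actual data
train_entry = {
    "image": "train2014/COCO_train2014_000000123456.jpg",  # joined with image_root
    "caption": "a dog catches a frisbee in the park",       # one caption per training entry
    "image_id": "123456",                                    # str id, remapped to a contiguous int
}
test_entry = {
    "image": "val2014/COCO_val2014_000000654321.jpg",
    "caption": [                                             # list: one image, several captions
        "a man riding a wave on a surfboard",
        "a surfer in a wetsuit on the ocean",
    ],
}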
data/transforms.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import re
+ from typing import List, Tuple, Union
+
+ import torch
+
+ from torchtext.transforms import PadTransform, Sequential, ToTensor, Truncate
+ from torchvision import transforms
+ from transformers.models.bert.tokenization_bert import BertTokenizer
+
+ # mean and standard deviation from the ALBEF repo:
+ # https://github.com/salesforce/ALBEF/blob/main/dataset/__init__.py#L16
+ MEAN = (0.48145466, 0.4578275, 0.40821073)
+ STD_DEV = (0.26862954, 0.26130258, 0.27577711)
+
+
+ class ALBEFTextTransform:
+     """
+     Remove punctuation and trailing spaces from the input text and transform it
+     into a Tensor of token ids using BertTokenizer.
+
+     Args:
+         pretrained_tokenizer (str): Pretrained tokenizer to use.
+             Default: "bert-base-uncased"
+         do_pre_process (bool): Whether to pre-process input text.
+             Defaults to True.
+         truncate (bool): Whether to truncate input text to max_seq_len.
+             Defaults to False.
+         pad_to_max_seq_len (bool): Whether to pad the sequence to max_seq_len.
+         add_end_token (bool): Whether to add the end-of-sentence token.
+             Defaults to True.
+         max_seq_len (int): The max sequence length after truncating or padding.
+             Defaults to 25.
+         cls_token_id (int): Value to represent the start of each text.
+             Defaults to 101, Hugging Face's BERT cls token id.
+         sep_token_id (int): Value to represent the end of each text.
+             Defaults to 102, Hugging Face's BERT sep token id.
+         pad_token_id (int): Value with which to pad each text so that all texts are the same length.
+             Defaults to 0, Hugging Face's BERT pad token id.
+
+     Inputs:
+         text (Union[List[str], str]): Input text to transform.
+     """
+
+     def __init__(
+         self,
+         pretrained_tokenizer: str = "bert-base-uncased",
+         do_pre_process: bool = True,
+         truncate: bool = False,
+         pad_to_max_seq_len: bool = False,
+         add_end_token: bool = True,
+         max_seq_len: int = 25,
+         cls_token_id: int = 101,
+         sep_token_id: int = 102,
+         pad_token_id: int = 0,
+     ):
+         self.do_pre_process = do_pre_process
+         self.cls_token_id = cls_token_id
+         self.sep_token_id = sep_token_id
+         self.pad_token_id = pad_token_id
+         self.add_end_token = add_end_token
+
+         self.tokenizer = BertTokenizer.from_pretrained(pretrained_tokenizer)
+         self.transform = Sequential(
+             Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
+             ToTensor(padding_value=self.pad_token_id),
+             PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
+             if pad_to_max_seq_len
+             else torch.nn.Identity(),
+         )
+
+     def pre_process(self, text: str) -> str:
+         text = (
+             re.sub(
+                 r"([,.'!?\"()*#:;~])",
+                 "",
+                 text,
+             )
+             .replace("-", " ")
+             .replace("/", " ")
+         )
+         text = text.rstrip(" ")
+
+         return text
+
+     def __call__(self, text: Union[List[str], str]) -> torch.Tensor:
+         if self.do_pre_process:
+             if isinstance(text, str):
+                 text = self.pre_process(text)
+             else:
+                 text = [self.pre_process(t) for t in text]
+         tokens = self.tokenizer(text)["input_ids"]
+         if not self.add_end_token and tokens[-1] == self.sep_token_id:
+             tokens = tokens[:-1]
+         input_ids = self.transform(tokens)
+
+         return input_ids
+
+
+ def training_image_transform(
+     image_size: int = 384,
+     scale: Tuple[float, float] = (0.5, 1.0),
+     image_interpolation=transforms.InterpolationMode.BICUBIC,
+     mean: Tuple[float, float, float] = MEAN,
+     std_dev: Tuple[float, float, float] = STD_DEV,
+ ) -> transforms.Compose:
+     return transforms.Compose(
+         [
+             transforms.RandomResizedCrop(
+                 image_size, scale=scale, interpolation=image_interpolation
+             ),
+             transforms.RandomHorizontalFlip(),
+             transforms.RandAugment(2, 7),
+             transforms.ToTensor(),
+             transforms.Normalize(mean, std_dev),
+         ]
+     )
+
+
+ def testing_image_transform(
+     image_size: int = 384,
+     image_interpolation=transforms.InterpolationMode.BICUBIC,
+     mean: Tuple[float, float, float] = MEAN,
+     std_dev: Tuple[float, float, float] = STD_DEV,
+ ) -> transforms.Compose:
+     return transforms.Compose(
+         [
+             transforms.Resize(
+                 (image_size, image_size), interpolation=image_interpolation
+             ),
+             transforms.ToTensor(),
+             transforms.Normalize(mean, std_dev),
+         ]
+     )
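
A short usage sketch for the transforms above (the caption and the blank image are stand-ins; instantiating ALBEFTextTransform downloads the bert-base-uncased tokenizer on first use):

from PIL import Image
from data.transforms import ALBEFTextTransform, training_image_transform

text_transform = ALBEFTextTransform(truncate=True, max_seq_len=30, add_end_token=False)
tokens = text_transform("A man riding a wave on a surfboard!")  # punctuation stripped, then tokenized
print(tokens)  # 1D LongTensor of BERT input ids, starting with cls token id 101

image_transform = training_image_transform()
image = Image.new("RGB", (640, 480))  # stand-in for a real photo
pixels = image_transform(image)
print(pixels.shape)  # torch.Size([3, 384, 384])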
data/vqa_datamodules.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Optional, Tuple
8
+
9
+ import torch
10
+ from data.transforms import (
11
+ ALBEFTextTransform,
12
+ testing_image_transform,
13
+ training_image_transform,
14
+ )
15
+ from data.vqa_dataset import VQADataset
16
+ from pytorch_lightning import LightningDataModule
17
+ from torch import Tensor
18
+ from torch.nn.utils.rnn import pad_sequence
19
+ from torch.utils.data import DataLoader, DistributedSampler
20
+
21
+
22
+ class VQADataModule(LightningDataModule):
23
+ """
24
+ The Data Module for Visual Question Answering task.
25
+
26
+ Args:
27
+ train_files (List[str]): The paths to training json files.
28
+ test_files (List[str]): The paths to testing json files.
29
+ answer_list (str): The path to the answers list.
30
+ vqa_root (str): The path to vqa data directory.
31
+ vg_root (str): The path to vg data directory.
32
+ batch_size (int): The sampling batch size.
33
+ num_workers (int): The number of workers for the distributed mode.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ train_files: List[str],
39
+ test_files: List[str],
40
+ answer_list: str,
41
+ vqa_root: str,
42
+ vg_root: str,
43
+ batch_size: int,
44
+ num_workers: int,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.train_dataset = VQADataset(
48
+ train_files,
49
+ vqa_root,
50
+ vg_root,
51
+ image_transform=training_image_transform(),
52
+ question_transform=ALBEFTextTransform(
53
+ truncate=True, max_seq_len=25, add_end_token=False
54
+ ),
55
+ answer_transform=ALBEFTextTransform(do_pre_process=False),
56
+ split="train",
57
+ )
58
+
59
+ self.test_dataset = VQADataset(
60
+ test_files,
61
+ vqa_root,
62
+ vg_root,
63
+ image_transform=testing_image_transform(),
64
+ question_transform=ALBEFTextTransform(add_end_token=False),
65
+ answer_transform=ALBEFTextTransform(do_pre_process=False),
66
+ split="test",
67
+ answer_list=answer_list,
68
+ )
69
+
70
+ self.batch_size = batch_size
71
+ self.num_workers = num_workers
72
+
73
+ def _get_sampler(
74
+ self,
75
+ dataset: VQADataset,
76
+ shuffle: bool,
77
+ is_distributed: bool,
78
+ num_tasks: int,
79
+ global_rank: int,
80
+ ) -> Optional[DistributedSampler]:
81
+ if not is_distributed:
82
+ return None
83
+
84
+ return DistributedSampler(
85
+ dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
86
+ )
87
+
88
+ def train_dataloader(
89
+ self,
90
+ is_distributed: bool = False,
91
+ num_tasks: int = 0,
92
+ global_rank: int = 0,
93
+ drop_last: bool = True,
94
+ ) -> DataLoader:
95
+ """
96
+ DataLoader Outputs:
97
+ images (Tensor): Tensor of shape (B, C, H, W) of image inputs.
98
+ questions (Tensor): Tensor of shape (B, L) of question inputs.
99
+ question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
100
+ answers (Tensor): Tensor of shape (N, M) of answer inputs.
101
+ N >= B because a vqa sample can have multiple answers.
102
+ answer_atts (Tensor): Tensor of shape (N, M) of answer attention mask.
103
+ weights (Tensor): Tensor of shape (N) of answer weights.
104
+ ans_lengths (List[int]): List of length B and sum N where
105
+ ans_lengths[i] = number of answers for images[i] and questions[i].
106
+ """
107
+ sampler = self._get_sampler(
108
+ dataset=self.train_dataset,
109
+ shuffle=True,
110
+ is_distributed=is_distributed,
111
+ num_tasks=num_tasks,
112
+ global_rank=global_rank,
113
+ )
114
+ shuffle = sampler is None
115
+ return DataLoader(
116
+ self.train_dataset,
117
+ batch_size=self.batch_size,
118
+ num_workers=self.num_workers,
119
+ pin_memory=True,
120
+ sampler=sampler,
121
+ shuffle=shuffle,
122
+ collate_fn=vqa_train_collate_fn,
123
+ drop_last=drop_last,
124
+ )
125
+
126
+ def test_dataloader(
127
+ self,
128
+ is_distributed: bool = False,
129
+ num_tasks: int = 0,
130
+ global_rank: int = 0,
131
+ drop_last=False,
132
+ ) -> DataLoader:
133
+ """
134
+ DataLoader Outputs:
135
+ images (Tensor): Tensor of shape (B, C, H, W) of image inputs.
136
+ questions (Tensor): Tensor of shape (B, L) of question inputs.
137
+ question_atts (Tensor): Tensor of shape (B, L) of question attention mask.
138
+ question_ids (List): List of length B of question ids.
139
+ """
140
+ sampler = self._get_sampler(
141
+ dataset=self.test_dataset,
142
+ shuffle=False,
143
+ is_distributed=is_distributed,
144
+ num_tasks=num_tasks,
145
+ global_rank=global_rank,
146
+ )
147
+ return DataLoader(
148
+ self.test_dataset,
149
+ batch_size=self.batch_size,
150
+ num_workers=self.num_workers,
151
+ pin_memory=True,
152
+ sampler=sampler,
153
+ shuffle=False,
154
+ collate_fn=vqa_test_collate_fn,
155
+ drop_last=drop_last,
156
+ )
157
+
158
+
159
+ def vqa_train_collate_fn(
160
+ batch: List[Tuple[Tensor, Tensor, List[Tensor], List[float]]]
161
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, List[int]]:
162
+ image_list = []
163
+ question_list = []
164
+ answer_list = []
165
+ weight_list = []
166
+ ans_lengths = []
167
+ for image, question, answer, weights in batch:
168
+ image_list.append(image)
169
+ question_list.append(question)
170
+ answer_list += answer
171
+ weight_list += weights
172
+ ans_lengths.append(len(answer))
173
+ images = torch.stack(image_list, dim=0)
174
+ questions = pad_sequence(question_list, batch_first=True)
175
+ question_atts = (questions != 0).type(torch.long)
176
+ answers = pad_sequence(answer_list, batch_first=True)
177
+ answer_atts = (answers != 0).type(torch.long)
178
+ weights = torch.Tensor(weight_list)
179
+ return (
180
+ images,
181
+ questions,
182
+ question_atts,
183
+ answers,
184
+ answer_atts,
185
+ weights,
186
+ ans_lengths,
187
+ )
188
+
189
+
190
+ def vqa_test_collate_fn(
191
+ batch: List[Tuple[Tensor, Tensor, int]]
192
+ ) -> Tuple[Tensor, Tensor, Tensor, List[int]]:
193
+ image_list, question_list, question_ids = [], [], []
194
+ for image, question, question_id in batch:
195
+ image_list.append(image)
196
+ question_list.append(question)
197
+ question_ids.append(question_id)
198
+ images = torch.stack(image_list, dim=0)
199
+ questions = pad_sequence(question_list, batch_first=True)
200
+ question_atts = (questions != 0).type(torch.long)
201
+ return (
202
+ images,
203
+ questions,
204
+ question_atts,
205
+ question_ids,
206
+ )
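
For illustration, a minimal sketch of the flattening that vqa_train_collate_fn above performs when questions carry different numbers of candidate answers; the dummy token ids and tensor sizes are assumptions, not values from the repo.

import torch
from torch.nn.utils.rnn import pad_sequence

# Two (image, question, answers, weights) samples shaped like VQADataset train outputs:
# the first question has two candidate answers, the second has one.
batch = [
    (torch.zeros(3, 4, 4), torch.tensor([101, 7, 102]),
     [torch.tensor([101, 9, 102]), torch.tensor([101, 11, 102])], [0.7, 0.3]),
    (torch.zeros(3, 4, 4), torch.tensor([101, 8, 5, 102]),
     [torch.tensor([101, 12, 102])], [0.5]),
]
images = torch.stack([b[0] for b in batch])                                  # (B, C, H, W)
questions = pad_sequence([b[1] for b in batch], batch_first=True)            # (B, L)
answers = pad_sequence([a for b in batch for a in b[2]], batch_first=True)   # (N, M), N >= B
ans_lengths = [len(b[2]) for b in batch]                                     # sums to N
print(images.shape, questions.shape, answers.shape, ans_lengths)
# torch.Size([2, 3, 4, 4]) torch.Size([2, 4]) torch.Size([3, 3]) [2, 1]
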
data/vqa_dataset.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ from typing import Callable, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from PIL import Image
14
+ from torch import Tensor
15
+ from torch.utils.data import Dataset
16
+
17
+
18
+ class VQADataset(Dataset):
19
+ """
20
+ Create the dataset for the VQA task.
21
+
22
+ Args:
23
+ ann_file (List[str]): The paths to annotation json files.
24
+ vqa_root (str): The path to vqa data directory.
25
+ vg_root (str): The path to vg data directory.
26
+ image_transform (Callable[[Image.Image], Tensor]): image data transform.
27
+ question_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for questions.
28
+ answer_transform (Callable[[Union[List[str], str]], Tensor]): text data transform for answers.
29
+ split (str): Indicates train or test. Default is train.
30
+ answer_list (str): The path to the answers list. Required for test split.
31
+
32
+ Dataset Outputs:
33
+ if split is train:
34
+ image (Tensor): Transformed image input tensor of shape (C, H, W).
35
+ question (Tensor): Transformed question token input ids.
36
+ answers (List[Tensor]): List of transformed answers token input ids.
37
+ answer_weights (List[float]): List of answer weights.
38
+ answer_weights[i] is proportional to the number of occurrences of answers[i]
39
+ if split is test:
40
+ image (Tensor): Transformed image input tensor of shape (C, H, W).
41
+ question (Tensor): Transformed question token input ids.
42
+ question_id (int): The question sample id.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ ann_file: List[str],
48
+ vqa_root: str,
49
+ vg_root: str,
50
+ image_transform: Callable[[Image.Image], Tensor],
51
+ question_transform: Callable[[Union[List[str], str]], Tensor],
52
+ answer_transform: Callable[[Union[List[str], str]], Tensor],
53
+ split: str = "train",
54
+ answer_list: Optional[str] = None,
55
+ ) -> None:
56
+ self.ann = []
57
+ for f in ann_file:
58
+ self.ann += json.load(open(f, "r"))
59
+
60
+ self.vqa_root = vqa_root
61
+ self.vg_root = vg_root
62
+ self.image_transform = image_transform
63
+ self.question_transform = question_transform
64
+ self.answer_transform = answer_transform
65
+ self.split = split
66
+
67
+ if split == "test":
68
+ self.answer_list = json.load(open(answer_list, "r"))
69
+ self.answer_input_ids = self.answer_transform(self.answer_list)
70
+ self.answer_attention_mask = (self.answer_input_ids != 0).type(torch.long)
71
+
72
+ def __len__(self) -> int:
73
+ return len(self.ann)
74
+
75
+ def __getitem__(
76
+ self, index: int
77
+ ) -> Union[
78
+ Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, List[Tensor], List[float]]
79
+ ]:
80
+ ann = self.ann[index]
81
+
82
+ image_root = self.vqa_root if ann["dataset"] == "vqa" else self.vg_root
83
+ image_path = os.path.join(image_root, ann["image"])
84
+ image = Image.open(image_path).convert("RGB")
85
+ image = self.image_transform(image)
86
+ question = self.question_transform(ann["question"])
87
+
88
+ if self.split == "test":
89
+ return image, question, ann["question_id"]
90
+
91
+ elif self.split == "train":
92
+ if ann["dataset"] == "vqa":
93
+ # Each VQA sample question has a list of answers (with potential repeats)
94
+ # answer_weight[answer] = count(answer) / len(answers for the question)
95
+ answer_weights = {}
96
+ for answer in ann["answer"]:
97
+ if answer in answer_weights.keys():
98
+ answer_weights[answer] += 1 / len(ann["answer"])
99
+ else:
100
+ answer_weights[answer] = 1 / len(ann["answer"])
101
+
102
+ answers = list(answer_weights.keys())
103
+ answer_weights = list(answer_weights.values())
104
+
105
+ elif ann["dataset"] == "vg":
106
+ # A VG sample question has one answer so assign it a constant weight (0.5)
107
+ answers = [ann["answer"]]
108
+ answer_weights = [0.5]
109
+
110
+ answers = list(self.answer_transform(answers))
111
+
112
+ return image, question, answers, answer_weights
113
+
114
+ else:
115
+ raise ValueError("dataset split should be train or test")
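
The answer weights above reduce to per-answer frequencies among the annotators' answers. A standalone sketch using one of the sample annotations from vqa_data.json in this commit:

# 10 annotator answers for the "How many bears are there?" sample in vqa_data.json
answers_raw = ["2", "3", "3", "4", "2", "2", "3", "3", "2", "3"]
weights = {}
for a in answers_raw:
    weights[a] = weights.get(a, 0) + 1 / len(answers_raw)
print(weights)  # roughly {'2': 0.4, '3': 0.5, '4': 0.1}, up to float rounding
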
finetune_retrieval.py ADDED
@@ -0,0 +1,400 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import datetime
9
+ import os
10
+ import random
11
+ import time
12
+
13
+ import ruamel.yaml as yaml
14
+ import torch
15
+ import torch.backends.cudnn as cudnn
16
+ import torch.distributed as dist
17
+ from data.retrieval_datamodule import RetrievalDataModule
18
+ from model import albef_model_for_retrieval
19
+ from torch.optim import AdamW
20
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
21
+ from utils import (
22
+ add_weight_decay,
23
+ get_rank,
24
+ get_world_size,
25
+ init_distributed_mode,
26
+ is_dist_avail_and_initialized,
27
+ is_main_process,
28
+ )
29
+
30
+
31
+ def train(model, datamodule, args, device):
32
+ model.train()
33
+
34
+ model_without_ddp = model.module if is_dist_avail_and_initialized() else model
35
+
36
+ optimizer_params = add_weight_decay(model, args["weight_decay"])
37
+ optimizer = AdamW(optimizer_params, lr=args["lr"])
38
+ scheduler = CosineAnnealingWarmRestarts(
39
+ optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
40
+ )
41
+
42
+ step_size = args["step_size"]
43
+ warmup_steps = args["warmup_steps"]
44
+ warmup_iterations = warmup_steps * step_size
45
+
46
+ data_loader = datamodule.train_dataloader(
47
+ is_distributed=is_dist_avail_and_initialized(),
48
+ num_tasks=get_world_size(),
49
+ global_rank=get_rank(),
50
+ )
51
+
52
+ start_time = time.time()
53
+
54
+ for epoch in range(args["max_epochs"]):
55
+ if epoch > 0:
56
+ scheduler.step(epoch + warmup_steps)
57
+
58
+ for batch, (image, text, text_atts, idx) in enumerate(data_loader):
59
+ if epoch > 0:
60
+ alpha = args["alpha"]
61
+ else:
62
+ alpha = args["alpha"] * min(1, batch / len(data_loader))
63
+
64
+ image = image.to(device, non_blocking=True)
65
+ text = text.to(device)
66
+ text_atts = text_atts.to(device)
67
+ idx = idx.to(device, non_blocking=True)
68
+ loss = model(image, text, text_atts, idx, alpha, is_train=True)
69
+
70
+ optimizer.zero_grad()
71
+ loss.backward()
72
+ optimizer.step()
73
+
74
+ if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
75
+ scheduler.step(batch // step_size)
76
+
77
+ if batch % args["log_every_n_steps"] == 0:
78
+ total_time = time.time() - start_time
79
+ time_str = "time {},".format(
80
+ datetime.timedelta(seconds=int(total_time))
81
+ )
82
+ epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
83
+ batch_str = "batch {}/{},".format(batch, len(data_loader))
84
+ loss_str = "loss {}".format(loss.item())
85
+ print(time_str, epoch_str, batch_str, loss_str)
86
+
87
+ if is_main_process():
88
+ save_obj = {
89
+ "model": model_without_ddp.state_dict(),
90
+ "optimizer": optimizer.state_dict(),
91
+ "lr_scheduler": scheduler.state_dict(),
92
+ "epoch": epoch,
93
+ }
94
+ torch.save(
95
+ save_obj,
96
+ os.path.join(
97
+ args["checkpoint_root"], "retrieval_checkpoint_%02d.pt" % epoch
98
+ ),
99
+ )
100
+
101
+ if is_dist_avail_and_initialized():
102
+ dist.barrier()
103
+ torch.cuda.empty_cache()
104
+
105
+
106
+ @torch.no_grad()
107
+ def encode_text(model, text_dataloader, device):
108
+ text_embeds = []
109
+ text_feats = []
110
+ text_atts = []
111
+ for text, text_att in text_dataloader:
112
+ text = text.to(device)
113
+ text_att = text_att.to(device)
114
+ text_embed, text_feat = model(
115
+ text=text, text_atts=text_att, input_type="text", is_train=False
116
+ )
117
+ text_embeds.append(text_embed)
118
+ text_feats.append(text_feat)
119
+ text_atts.append(text_att)
120
+ text_embeds = torch.cat(text_embeds, dim=0)
121
+ text_feats = torch.cat(text_feats, dim=0)
122
+ text_atts = torch.cat(text_atts, dim=0)
123
+ return text_embeds, text_feats, text_atts
124
+
125
+
126
+ @torch.no_grad()
127
+ def encode_image(model, image_dataloader, device):
128
+ image_embeds = []
129
+ image_feats = []
130
+ for image in image_dataloader:
131
+ image = image.to(device)
132
+ image_embed, image_feat = model(image=image, input_type="image", is_train=False)
133
+ image_embeds.append(image_embed)
134
+ image_feats.append(image_feat)
135
+ image_embeds = torch.cat(image_embeds, dim=0)
136
+ image_feats = torch.cat(image_feats, dim=0)
137
+ return image_embeds, image_feats
138
+
139
+
140
+ @torch.no_grad()
141
+ def image_to_text(
142
+ model,
143
+ image_embeds,
144
+ text_embeds,
145
+ text_atts,
146
+ sims_matrix,
147
+ num_images,
148
+ num_text,
149
+ device,
150
+ args,
151
+ ):
152
+ start_time = time.time()
153
+ world_size = get_world_size()
154
+ rank = get_rank()
155
+ step = sims_matrix.size(0) // world_size + 1
156
+ start = rank * step
157
+ end = min(sims_matrix.size(0), start + step)
158
+ k = args["k_test"]
159
+
160
+ image_to_text_scores = torch.full((num_images, num_text), -100.0).to(device)
161
+ for i, sims in enumerate(sims_matrix[start:end]):
162
+ _, topk_idx = sims.topk(k, dim=0)
163
+
164
+ score = model(
165
+ image=image_embeds[start + i].repeat(k, 1, 1),
166
+ text=text_embeds[topk_idx],
167
+ text_atts=text_atts[topk_idx],
168
+ input_type="multimodal",
169
+ is_train=False,
170
+ )
171
+ image_to_text_scores[start + i, topk_idx] = score
172
+
173
+ if i % args["log_every_n_steps"] == 0:
174
+ total_time = time.time() - start_time
175
+ time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
176
+ batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
177
+ print("image to text retrieval", time_str, batch_str)
178
+ return image_to_text_scores
179
+
180
+
181
+ @torch.no_grad()
182
+ def text_to_image(
183
+ model,
184
+ image_embeds,
185
+ text_embeds,
186
+ text_atts,
187
+ sims_matrix,
188
+ num_images,
189
+ num_text,
190
+ device,
191
+ args,
192
+ ):
193
+ start_time = time.time()
194
+ world_size = get_world_size()
195
+ rank = get_rank()
196
+ step = sims_matrix.size(0) // world_size + 1
197
+ start = rank * step
198
+ end = min(sims_matrix.size(0), start + step)
199
+ k = args["k_test"]
200
+
201
+ text_to_image_scores = torch.full((num_text, num_images), -100.0).to(device)
202
+ for i, sims in enumerate(sims_matrix[start:end]):
203
+ _, topk_idx = sims.topk(k, dim=0)
204
+ score = model(
205
+ image=image_embeds[topk_idx],
206
+ text=text_embeds[start + i].repeat(k, 1, 1),
207
+ text_atts=text_atts[start + i].repeat(k, 1, 1),
208
+ input_type="multimodal",
209
+ is_train=False,
210
+ )
211
+ text_to_image_scores[start + i, topk_idx] = score
212
+
213
+ if i % args["log_every_n_steps"] == 0:
214
+ total_time = time.time() - start_time
215
+ time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
216
+ batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
217
+ print("text to image retrieval", time_str, batch_str)
218
+ return text_to_image_scores
219
+
220
+
221
+ @torch.no_grad()
222
+ def evaluation(model, datamodule, args, device):
223
+ model.eval()
224
+
225
+ text_loader = datamodule.text_dataloader()
226
+ image_loader = datamodule.image_dataloader()
227
+ num_images = len(datamodule.image_dataset)
228
+ num_text = len(datamodule.text_dataset)
229
+
230
+ text_embeds, text_feats, text_atts = encode_text(model, text_loader, device)
231
+ image_embeds, image_feats = encode_image(model, image_loader, device)
232
+
233
+ sims_matrix = image_feats @ text_feats.t()
234
+ image_to_text_scores = image_to_text(
235
+ model,
236
+ image_embeds,
237
+ text_embeds,
238
+ text_atts,
239
+ sims_matrix,
240
+ num_images,
241
+ num_text,
242
+ device,
243
+ args,
244
+ )
245
+
246
+ sims_matrix = sims_matrix.t()
247
+ text_to_image_scores = text_to_image(
248
+ model,
249
+ image_embeds,
250
+ text_embeds,
251
+ text_atts,
252
+ sims_matrix,
253
+ num_images,
254
+ num_text,
255
+ device,
256
+ args,
257
+ )
258
+
259
+ if is_dist_avail_and_initialized():
260
+ dist.barrier()
261
+ torch.distributed.all_reduce(
262
+ image_to_text_scores, op=torch.distributed.ReduceOp.SUM
263
+ )
264
+ torch.distributed.all_reduce(
265
+ text_to_image_scores, op=torch.distributed.ReduceOp.SUM
266
+ )
267
+
268
+ return image_to_text_scores.cpu(), text_to_image_scores.cpu()
269
+
270
+
271
+ @torch.no_grad()
272
+ def itm_eval(
273
+ image_to_text_scores,
274
+ text_to_image_scores,
275
+ image_to_text_mapping,
276
+ text_to_image_mapping,
277
+ ):
278
+ # Images to Text
279
+ ranks = torch.zeros(image_to_text_scores.size(0))
280
+ for index, score in enumerate(image_to_text_scores):
281
+ inds = torch.flip(torch.argsort(score), dims=[0])
282
+ rank = 1e10
283
+ # each image has multiple text mappings
284
+ # check retrieved inds against each ground truth mapping i
285
+ for i in image_to_text_mapping[index]:
286
+ tmp = torch.where(inds == i)[0][0]
287
+ if tmp < rank:
288
+ rank = tmp
289
+ ranks[index] = rank
290
+
291
+ # Compute metrics
292
+ tr1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
293
+ tr5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
294
+ tr10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)
295
+
296
+ # Text to Images
297
+ ranks = torch.zeros(text_to_image_scores.size(0))
298
+ for index, score in enumerate(text_to_image_scores):
299
+ inds = torch.flip(torch.argsort(score), dims=[0])
300
+ ranks[index] = torch.where(inds == text_to_image_mapping[index])[0][0]
301
+
302
+ # Compute metrics
303
+ ir1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
304
+ ir5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
305
+ ir10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)
306
+
307
+ tr_mean = (tr1 + tr5 + tr10) / 3
308
+ ir_mean = (ir1 + ir5 + ir10) / 3
309
+ r_mean = (tr_mean + ir_mean) / 2
310
+
311
+ eval_result = {
312
+ "txt_r1": tr1,
313
+ "txt_r5": tr5,
314
+ "txt_r10": tr10,
315
+ "txt_r_mean": tr_mean,
316
+ "img_r1": ir1,
317
+ "img_r5": ir5,
318
+ "img_r10": ir10,
319
+ "img_r_mean": ir_mean,
320
+ "r_mean": r_mean,
321
+ }
322
+ return eval_result
323
+
324
+
325
+ @torch.no_grad()
326
+ def format_output(
327
+ image_to_text_scores,
328
+ text_to_image_scores,
329
+ image_dataset,
330
+ text_dataset,
331
+ ):
332
+ image_to_text_output = {}
333
+ for index, score in enumerate(image_to_text_scores):
334
+ image = image_dataset.images[index]
335
+ top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
336
+ top10_text = [text_dataset.text[i] for i in top10_ids]
337
+ image_to_text_output[index] = {
338
+ "image": image,
339
+ "output": top10_text,
340
+ }
341
+ text_to_image_output = {}
342
+ for index, score in enumerate(text_to_image_scores):
343
+ text = text_dataset.text[index]
344
+ top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
345
+ top10_images = [image_dataset.images[i] for i in top10_ids]
346
+ text_to_image_output[index] = {
347
+ "text": text,
348
+ "output": top10_images,
349
+ }
350
+ return image_to_text_output, text_to_image_output
351
+
352
+
353
+ def main():
354
+ parser = argparse.ArgumentParser()
355
+ parser.add_argument("--config", default="./examples/albef/configs/retrieval.yaml")
356
+ args = parser.parse_args()
357
+ config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
358
+
359
+ init_distributed_mode(config)
360
+ device = torch.device(config["device"])
361
+
362
+ seed = config["seed"] + get_rank()
363
+ torch.manual_seed(seed)
364
+ random.seed(seed)
365
+ cudnn.benchmark = True
366
+
367
+ datamodule = RetrievalDataModule(**config["datamodule_args"])
368
+ model = albef_model_for_retrieval(config, pretrained=True)
369
+ model = model.to(device)
370
+ if is_dist_avail_and_initialized():
371
+ model = torch.nn.parallel.DistributedDataParallel(
372
+ model, device_ids=[config["gpu"]]
373
+ )
374
+
375
+ train(model, datamodule, config["training_args"], device)
376
+ image_to_text_scores, text_to_image_scores = evaluation(
377
+ model, datamodule, config["eval_args"], device
378
+ )
379
+ val_result = itm_eval(
380
+ image_to_text_scores,
381
+ text_to_image_scores,
382
+ datamodule.image_dataset.image_to_text,
383
+ datamodule.text_dataset.text_to_image,
384
+ )
385
+ image_to_text_output, text_to_image_output = format_output(
386
+ image_to_text_scores,
387
+ text_to_image_scores,
388
+ datamodule.image_dataset,
389
+ datamodule.text_dataset,
390
+ )
391
+ result = {
392
+ "image_to_text_output": image_to_text_output,
393
+ "text_to_image_output": text_to_image_output,
394
+ **val_result,
395
+ }
396
+ torch.save(result, config["output_path"])
397
+
398
+
399
+ if __name__ == "__main__":
400
+ main()
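
For illustration, a toy run of the ranking logic inside itm_eval above (rank of the best-scoring ground-truth caption in descending score order, then recall@k); the 2x4 score matrix and the mappings are made-up values, not outputs of the repo.

import torch

# Toy image-to-text scores (2 images x 4 texts); image_to_text_mapping[i] lists the
# ground-truth caption indices for image i, as in itm_eval.
scores = torch.tensor([[0.1, 0.9, 0.3, 0.2],
                       [0.8, 0.1, 0.2, 0.7]])
image_to_text_mapping = [[1, 2], [3]]

ranks = torch.zeros(scores.size(0))
for index, score in enumerate(scores):
    inds = torch.flip(torch.argsort(score), dims=[0])   # text indices, best score first
    # rank of the best-ranked ground-truth caption for this image
    ranks[index] = min(torch.where(inds == i)[0][0] for i in image_to_text_mapping[index])
r1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)   # recall@1
print(ranks.tolist(), r1)  # [0.0, 1.0] 50.0
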
finetune_vqa.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import datetime
9
+ import os
10
+ import random
11
+ import time
12
+
13
+ import ruamel.yaml as yaml
14
+ import torch
15
+ import torch.backends.cudnn as cudnn
16
+ import torch.distributed as dist
17
+ from data.vqa_datamodules import VQADataModule
18
+ from model import albef_model_for_vqa
19
+ from torch.optim import AdamW
20
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
21
+
22
+ from utils import (
23
+ add_weight_decay,
24
+ get_rank,
25
+ get_world_size,
26
+ init_distributed_mode,
27
+ is_dist_avail_and_initialized,
28
+ is_main_process,
29
+ save_result,
30
+ )
31
+
32
+
33
+ def train(model, datamodule, args, device):
34
+ model_without_ddp = model.module if is_dist_avail_and_initialized() else model
35
+ model.train()
36
+
37
+ optimizer_params = add_weight_decay(model, args["weight_decay"])
38
+ optimizer = AdamW(optimizer_params, lr=args["lr"])
39
+ scheduler = CosineAnnealingWarmRestarts(
40
+ optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
41
+ )
42
+
43
+ step_size = args["step_size"]
44
+ warmup_steps = args["warmup_steps"]
45
+ warmup_iterations = warmup_steps * step_size
46
+
47
+ data_loader = datamodule.train_dataloader(
48
+ is_distributed=is_dist_avail_and_initialized(),
49
+ num_tasks=get_world_size(),
50
+ global_rank=get_rank(),
51
+ )
52
+
53
+ start_time = time.time()
54
+
55
+ for epoch in range(args["max_epochs"]):
56
+ if is_dist_avail_and_initialized():
57
+ data_loader.sampler.set_epoch(epoch)
58
+
59
+ if epoch > 0:
60
+ scheduler.step(epoch + warmup_steps)
61
+
62
+ for batch, (
63
+ images,
64
+ questions,
65
+ questions_atts,
66
+ answers,
67
+ answers_atts,
68
+ ans_weights,
69
+ ans_lengths,
70
+ ) in enumerate(data_loader):
71
+ if epoch > 0:
72
+ alpha = args["alpha"]
73
+ else:
74
+ alpha = args["alpha"] * min(1, batch / len(data_loader))
75
+
76
+ images = images.to(device, non_blocking=True)
77
+ questions = questions.to(device)
78
+ questions_atts = questions_atts.to(device)
79
+ answers = answers.to(device)
80
+ answers_atts = answers_atts.to(device)
81
+ ans_weights = ans_weights.to(device)
82
+
83
+ loss = model(
84
+ images,
85
+ questions,
86
+ questions_atts,
87
+ answers,
88
+ answers_atts,
89
+ ans_weights=ans_weights,
90
+ ans_lengths=ans_lengths,
91
+ alpha=alpha,
92
+ is_train=True,
93
+ )
94
+
95
+ optimizer.zero_grad()
96
+ loss.backward()
97
+ optimizer.step()
98
+
99
+ if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
100
+ scheduler.step(batch // step_size)
101
+
102
+ if batch % args["log_every_n_steps"] == 0:
103
+ total_time = time.time() - start_time
104
+ time_str = "time {},".format(
105
+ datetime.timedelta(seconds=int(total_time))
106
+ )
107
+ epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
108
+ batch_str = "batch {}/{},".format(batch, len(data_loader))
109
+ loss_str = "loss {}".format(loss.item())
110
+ print(time_str, epoch_str, batch_str, loss_str)
111
+
112
+ if is_main_process():
113
+ save_obj = {
114
+ "model": model_without_ddp.state_dict(),
115
+ "optimizer": optimizer.state_dict(),
116
+ "scheduler": scheduler.state_dict(),
117
+ "epoch": epoch,
118
+ }
119
+ torch.save(
120
+ save_obj,
121
+ os.path.join(args["checkpoint_root"], "vqa_checkpoint_%02d.pt" % epoch),
122
+ )
123
+
124
+ if is_dist_avail_and_initialized():
125
+ dist.barrier()
126
+
127
+
128
+ @torch.no_grad()
129
+ def evaluation(model, datamodule, args, device):
130
+ model.eval()
131
+
132
+ result = []
133
+
134
+ answer_list = datamodule.test_dataset.answer_list
135
+ answer_input_ids = datamodule.test_dataset.answer_input_ids.to(device)
136
+ answer_atts = datamodule.test_dataset.answer_attention_mask.to(device)
137
+ data_loader = datamodule.test_dataloader(
138
+ is_distributed=is_dist_avail_and_initialized(),
139
+ num_tasks=get_world_size(),
140
+ global_rank=get_rank(),
141
+ )
142
+
143
+ start_time = time.time()
144
+
145
+ for batch, (img, ques, ques_atts, ques_ids) in enumerate(data_loader):
146
+ img = img.to(device, non_blocking=True)
147
+ ques = ques.to(device)
148
+ ques_atts = ques_atts.to(device)
149
+
150
+ topk_ids, topk_probs = model(
151
+ img,
152
+ ques,
153
+ ques_atts,
154
+ answer_input_ids,
155
+ answer_atts,
156
+ k=args["k_test"],
157
+ is_train=False,
158
+ )
159
+
160
+ for ques_id, topk_id, topk_prob in zip(ques_ids, topk_ids, topk_probs):
161
+ _, pred = topk_prob.max(dim=0)
162
+ result.append(
163
+ {"question_id": ques_id, "answer": answer_list[topk_id[pred]]}
164
+ )
165
+
166
+ if batch % args["log_every_n_steps"] == 0:
167
+ total_time = time.time() - start_time
168
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
169
+ print(
170
+ "time {}, batch {}/{}".format(total_time_str, batch, len(data_loader))
171
+ )
172
+
173
+ return result
174
+
175
+
176
+ def main():
177
+ parser = argparse.ArgumentParser()
178
+ parser.add_argument("--config", default="./examples/albef/configs/vqa.yaml")
179
+ args = parser.parse_args()
180
+ config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
181
+
182
+ init_distributed_mode(config)
183
+ device = torch.device(config["device"])
184
+
185
+ seed = config["seed"] + get_rank()
186
+ torch.manual_seed(seed)
187
+ random.seed(seed)
188
+ cudnn.benchmark = True
189
+
190
+ datamodule = VQADataModule(**config["datamodule_args"])
191
+ model = albef_model_for_vqa(config, pretrained=True)
192
+ model = model.to(device)
193
+ if is_dist_avail_and_initialized():
194
+ model = torch.nn.parallel.DistributedDataParallel(
195
+ model, device_ids=[config["gpu"]]
196
+ )
197
+
198
+ train(model, datamodule, config["training_args"], device)
199
+ result = evaluation(model, datamodule, config["eval_args"], device)
200
+ save_result(result, config["output_root"], "vqa_output")
201
+
202
+
203
+ if __name__ == "__main__":
204
+ main()
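
For illustration, a toy run of the scheduler stepping pattern shared by both train() loops above: a scheduler step every step_size batches during epoch 0 (the warmup phase), then one step per epoch offset by warmup_steps. The learning rate, epoch count, step_size and warmup_steps below are placeholders, not the repo's config values.

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# One dummy parameter is enough to watch the learning rate move.
opt = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=2e-5)
sched = CosineAnnealingWarmRestarts(opt, T_0=10, eta_min=1e-6)
step_size, warmup_steps, batches_per_epoch = 100, 1, 300
for epoch in range(3):
    if epoch > 0:
        sched.step(epoch + warmup_steps)      # one scheduler step per epoch after warmup
    for batch in range(batches_per_epoch):
        if epoch == 0 and batch % step_size == 0 and batch <= warmup_steps * step_size:
            sched.step(batch // step_size)    # warmup stepping, once every step_size batches
    print(epoch, opt.param_groups[0]["lr"])
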
images/COCO_val2014_000000026348.jpg ADDED
images/COCO_val2014_000000057222.jpg ADDED
images/COCO_val2014_000000111207.jpg ADDED
images/COCO_val2014_000000159269.jpg ADDED
images/COCO_val2014_000000184359.jpg ADDED
images/COCO_val2014_000000407072.jpg ADDED
images/COCO_val2014_000000473994.jpg ADDED
images/COCO_val2014_000000552075.jpg ADDED
model.py ADDED
@@ -0,0 +1,666 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn, Tensor
13
+ from torchmultimodal.models.albef.image_encoder import ALBEFVisionEncoder
14
+ from torchmultimodal.models.albef.model import ALBEFModel, ALBEFModelWithSimilarity
15
+ from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
16
+ from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
17
+ from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
18
+ from torchmultimodal.modules.losses.albef import (
19
+ CausalLanguageModelingLoss,
20
+ ImageTextContrastiveLoss,
21
+ )
22
+ from torchmultimodal.utils.attention import get_causal_attention_mask
23
+ from torchmultimodal.utils.common import momentum_update, remove_grad
24
+
25
+
26
+ _ALBEF_PRETRAINED_URLS = {
27
+ "vqa": "https://download.pytorch.org/models/multimodal/albef/pretrained_vqa_checkpoint.pt",
28
+ "retrieval": "https://download.pytorch.org/models/multimodal/albef/pretrained_retrieval_checkpoint.pt",
29
+ }
30
+
31
+
32
+ class PredictionHead(nn.Module):
33
+ """
34
+ Predict the following token autoregressively.
35
+
36
+ Args:
37
+ vocab_size (int): The number of different tokens the prediction_head can predict.
38
+ hidden_size (int): The hidden size of the prediction_head.
39
+ layer_norm_eps (float): The epsilon used by the prediction_head normalization layer.
40
+ transform_act_fn (Callable[[Tensor], Tensor]): The activation function in the prediction_head.
41
+
42
+ Inputs:
43
+ hidden_states (Tensor): The hidden states of preceding tokens.
44
+
45
+ Returns:
46
+ Tensor: Prediction scores for the following token.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_size: int = 30522,
52
+ hidden_size: int = 768,
53
+ layer_norm_eps: float = 1e-12,
54
+ transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
55
+ ) -> None:
56
+ super().__init__()
57
+ self.dense = nn.Linear(hidden_size, hidden_size)
58
+ self.transform_act_fn = transform_act_fn
59
+ self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
60
+ self.decoder = nn.Linear(hidden_size, vocab_size)
61
+
62
+ def forward(self, hidden_states: Tensor) -> Tensor:
63
+ hidden_states = self.dense(hidden_states)
64
+ hidden_states = self.transform_act_fn(hidden_states)
65
+ hidden_states = self.layer_norm(hidden_states)
66
+ hidden_states = self.decoder(hidden_states)
67
+ return hidden_states
68
+
69
+
70
+ class ALBEFDecoder(nn.Module):
71
+ """
72
+ Generate the prediction scores for answers from image and question hidden states.
73
+
74
+ Args:
75
+ text_embeddings (BERTTextEmbeddings): Instantiated BERTTextEmbeddings.
76
+ multimodal_encoder (ALBEFMultimodalEncoder): Instantiated ALBEFMultimodalEncoder.
77
+ prediction_head (PredictionHead): Instantiated PredictionHead.
78
+
79
+ Inputs:
80
+ input_ids (Tensor of shape (batch_size, seq_len)):
81
+ Input ids for input text tokens.
82
+ attention_mask (Tensor of shape (batch_size, seq_len)):
83
+ Input attention mask to avoid performing attention on padding token indices.
84
+ encoder_hidden_states (Tensor of shape (batch_size, encoder_seq_len, hidden_size)):
85
+ The encoder hidden states.
86
+ encoder_attention_mask (Tensor of shape (batch_size, encoder_seq_len)):
87
+ The attention mask for encoder hidden states.
88
+
89
+ Returns:
90
+ Tensor: Prediction scores for answers.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ text_embeddings: BERTTextEmbeddings,
96
+ multimodal_encoder: ALBEFMultimodalEncoder,
97
+ prediction_head: PredictionHead,
98
+ ) -> None:
99
+ super().__init__()
100
+ self.text_embeddings = text_embeddings
101
+ self.multimodal_encoder = multimodal_encoder
102
+ self.prediction_head = prediction_head
103
+
104
+ def get_extended_attention_mask_for_decoder(self, attention_mask: Tensor) -> Tensor:
105
+ """
106
+ Apply a causal mask in addition to the padding mask and make the mask broadcastable,
107
+ such that future and masked tokens are ignored.
108
+
109
+ Args:
110
+ attention_mask (Tensor):
111
+ Padding mask with ones indicating tokens to attend to, zeros for tokens to ignore.
112
+
113
+ Returns:
114
+ extended_attention_mask (Tensor):
115
+ The broadcastable attention mask, with the same dtype as ``attention_mask.dtype``.
116
+ """
117
+ device = attention_mask.device
118
+ batch_size, seq_length = attention_mask.shape
119
+ causal_mask = get_causal_attention_mask(seq_length).to(device)
120
+ causal_mask = causal_mask.repeat(batch_size, 1).view(
121
+ batch_size, seq_length, seq_length
122
+ )
123
+ extended_attention_mask = (
124
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
125
+ )
126
+ extended_attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype)
127
+ return extended_attention_mask
128
+
129
+ def forward(
130
+ self,
131
+ input_ids: Tensor,
132
+ attention_mask: Tensor,
133
+ encoder_hidden_states: Tensor,
134
+ encoder_attention_mask: Tensor,
135
+ ) -> Tensor:
136
+ hidden_states = self.text_embeddings(input_ids)
137
+ attention_mask = self.get_extended_attention_mask_for_decoder(attention_mask)
138
+ decoder_output = self.multimodal_encoder(
139
+ hidden_states=hidden_states,
140
+ attention_mask=attention_mask,
141
+ encoder_hidden_states=encoder_hidden_states,
142
+ encoder_attention_mask=encoder_attention_mask,
143
+ )
144
+ prediction_scores = self.prediction_head(decoder_output)
145
+ return prediction_scores
146
+
147
+
148
+ class ALBEFModelForVQA(nn.Module):
149
+ """
150
+ ALBEF Model for VQA finetuning and inference.
151
+
152
+ Args:
153
+ model (ALBEFModel): Instantiated ALBEFModel.
154
+ answer_decoder (ALBEFDecoder): Instantiated ALBEFDecoder.
155
+ loss (CausalLanguageModelingLoss): Instantiated CausalLanguageModelingLoss.
156
+
157
+ Inputs:
158
+ image (Tensor of shape (B, C, H, W)): Image features.
159
+ question (Tensor of shape (B, L)): Question text features.
160
+ question_atts (Tensor of shape (B, L)): Question attention mask.
161
+ answers (Tensor of shape (N, M)): Answer text features.
162
+ answers_atts (Tensor of shape (N, M)): Answer attention mask.
163
+ ans_weights (Optional[Tensor] of shape (N)): Weights for each answer.
164
+ Required if is_train is True.
165
+ ans_lengths (Optional[List[int]] of length B): Number of answers for each question.
166
+ ans_lengths should sum to N.
167
+ Required if is_train is True.
168
+ alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
169
+ Required if is_train is True.
170
+ k (Optional[int]): The number of answers to return for inference.
171
+ Required if is_train is False.
172
+ is_train (Optional[bool]): Whether the model is in training.
173
+
174
+ Returns:
175
+ is_train is True:
176
+ Tensor: The masked language modeling loss for input.
177
+ is_train is False:
178
+ Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ model: ALBEFModel,
184
+ answer_decoder: ALBEFDecoder,
185
+ loss: CausalLanguageModelingLoss,
186
+ ) -> None:
187
+ super().__init__()
188
+ self.model = model
189
+ self.answer_decoder = answer_decoder
190
+ self.loss = loss
191
+ self.answer_decoder_m = copy.deepcopy(self.answer_decoder)
192
+ remove_grad(
193
+ self.answer_decoder_m
194
+ ) # remove gradient for the momentum decoder model
195
+
196
+ def _train_forward(
197
+ self,
198
+ image: Tensor,
199
+ question: Tensor,
200
+ question_atts: Tensor,
201
+ answers: Tensor,
202
+ answers_atts: Tensor,
203
+ ans_weights: Tensor,
204
+ ans_lengths: List[int],
205
+ alpha: float,
206
+ ) -> Tensor:
207
+ """
208
+ Forward step for training. Encode the inputs with the ALBEFModel.
209
+ Generate pseudo-targets using answer_decoder_m (momentum decoder model).
210
+ Generate answer predictions using answer_decoder.
211
+ Compute masked language modeling loss of the predictions using answers as labels,
212
+ pseudo-targets as soft-labels, and alpha as their interpolation value.
213
+
214
+ Inputs:
215
+ image (Tensor of shape (B, C, H, W)): Image features.
216
+ question (Tensor of shape (B, L)): Question text features.
217
+ question_atts (Tensor of shape (B, L)): Question attention mask.
218
+ answers (Tensor of shape (N, M)): Answer text features.
219
+ answers_atts (Tensor of shape (N, M)): Answer attention mask.
220
+ ans_weights (Tensor of shape (N)): Weights for each answer.
221
+ ans_lengths (List[int] of length B): Number of answers for each question.
222
+ ans_lengths should sum to N.
223
+ alpha (float): The interpolation value between clm_loss and loss_distill.
224
+
225
+ Returns:
226
+ Tensor: The masked language modeling loss for input.
227
+ """
228
+ # get image-question embeddings from the ALBEFModel and format it to match the ans_lengths
229
+ encoder_outputs = self.model(image, question, question_atts)
230
+ (
231
+ encoder_hidden_states,
232
+ encoder_hidden_states_m,
233
+ encoder_attention_mask,
234
+ ) = self._encoder_hidden_states(
235
+ encoder_outputs.multimodal_embeddings,
236
+ encoder_outputs.multimodal_embeddings_m,
237
+ question_atts,
238
+ ans_lengths,
239
+ )
240
+
241
+ # use the momentum model to generate pseudo-targets
242
+ with torch.no_grad():
243
+ momentum_update(
244
+ self.answer_decoder, self.answer_decoder_m, self.model.momentum
245
+ )
246
+ prediction_scores_m = self.answer_decoder_m(
247
+ input_ids=answers,
248
+ attention_mask=answers_atts,
249
+ encoder_hidden_states=encoder_hidden_states_m,
250
+ encoder_attention_mask=encoder_attention_mask,
251
+ )
252
+
253
+ # generate answer predictions
254
+ prediction_scores = self.answer_decoder(
255
+ input_ids=answers,
256
+ attention_mask=answers_atts,
257
+ encoder_hidden_states=encoder_hidden_states,
258
+ encoder_attention_mask=encoder_attention_mask,
259
+ )
260
+
261
+ # compute masked language modeling loss from the prediction scores
262
+ labels = answers.masked_fill(answers == 0, self.loss.mask_token_id)
263
+ loss = self.loss(labels, prediction_scores, prediction_scores_m, alpha)
264
+ loss = ans_weights * loss
265
+ loss = loss.sum() / image.size(0)
266
+ return loss
267
+
268
+ def _eval_forward(
269
+ self,
270
+ image: Tensor,
271
+ question: Tensor,
272
+ question_atts: Tensor,
273
+ answers: Tensor,
274
+ answer_atts: Tensor,
275
+ k: int = 128,
276
+ ) -> Tuple[Tensor, Tensor]:
277
+ """
278
+ Forward step for evaluation. Encode the inputs with the ALBEFModel.
279
+ Generate answer autoregressively using the decoder, starting with the [CLS] token.
280
+ Compute the answer ids and their respective probabilities for the top k predictions.
281
+
282
+ Inputs:
283
+ image (Tensor of shape (B, C, H, W)): Image features.
284
+ question (Tensor of shape (B, L)): Question text features.
285
+ question_atts (Tensor of shape (B, L)): Question attention mask.
286
+ answers (Tensor of shape (N, M)): Answer text features.
287
+ answer_atts (Tensor of shape (N, M)): Answer attention mask.
288
+ k (int): The number of answers to return for inference.
289
+
290
+ Returns:
291
+ Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
292
+ """
293
+ # get multimodal embeddings from the ALBEFModel and
294
+ # feed it to the decoder as cross attention
295
+ encoder_outputs = self.model(image, question, question_atts)
296
+
297
+ # use cls token as the decoder's initial input token
298
+ num_ques = question.size(0)
299
+ start_ids = answers[0, 0].repeat(num_ques, 1)
300
+ atts = torch.ones(start_ids.shape).to(image.device)
301
+
302
+ # auto-regressively generates the answer
303
+ prediction_scores = self.answer_decoder(
304
+ input_ids=start_ids,
305
+ attention_mask=atts,
306
+ encoder_hidden_states=encoder_outputs.multimodal_embeddings,
307
+ encoder_attention_mask=question_atts,
308
+ )
309
+
310
+ logits = prediction_scores[:, 0, :]
311
+ answer_first_token = answers[:, 1]
312
+ prob_first_token = F.softmax(logits, dim=1).index_select(
313
+ dim=1, index=answer_first_token
314
+ )
315
+ topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
316
+
317
+ input_ids = []
318
+ input_atts = []
319
+ for topk_id in topk_ids:
320
+ input_ids.append(answers.index_select(dim=0, index=topk_id))
321
+ input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
322
+ input_ids = torch.cat(input_ids)
323
+ input_atts = torch.cat(input_atts)
324
+ targets_ids = input_ids.masked_fill(input_ids == 0, self.loss.mask_token_id)
325
+
326
+ question_states = encoder_outputs.multimodal_embeddings.repeat_interleave(
327
+ k, dim=0
328
+ )
329
+ question_atts = question_atts.repeat_interleave(k, dim=0)
330
+
331
+ prediction_scores = self.answer_decoder(
332
+ input_ids=input_ids,
333
+ attention_mask=input_atts,
334
+ encoder_hidden_states=question_states,
335
+ encoder_attention_mask=question_atts,
336
+ )
337
+
338
+ answer_loss = self.loss(targets_ids, prediction_scores)
339
+ answer_loss = answer_loss.view(input_ids.size(0), -1)
340
+
341
+ # topk_prob: first token probability
342
+ topk_probs = topk_probs.view(-1, 1)
343
+ log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
344
+
345
+ # re-calculate log probabilities for the answer sequences using chain rule
346
+ log_probs_sum = log_probs.sum(1)
347
+ log_probs_sum = log_probs_sum.view(num_ques, k)
348
+
349
+ topk_probs = F.softmax(log_probs_sum, dim=-1)
350
+
351
+ # get top-k after re-ranking
352
+ topk_probs, rerank_id = topk_probs.topk(k, dim=1)
353
+ topk_ids = torch.gather(topk_ids, 1, rerank_id)
354
+
355
+ return topk_ids, topk_probs
356
+
357
+ def _encoder_hidden_states(
358
+ self,
359
+ multimodal_embeds: Tensor,
360
+ multimodal_embeds_m: Tensor,
361
+ question_atts: Tensor,
362
+ ans_lengths: List[int],
363
+ ) -> Tuple[Tensor, Tensor, Tensor]:
364
+ """
365
+ Repeat each image-question embedding and its attention mask to match the number of answers that sample has.
366
+
367
+ Args:
368
+ multimodal_embeds (Tensor): Image-question embeddings.
369
+ multimodal_embeds_m (Tensor): Image-question embeddings from the momentum model.
370
+ question_atts (Tensor): Question attention mask.
371
+ ans_lengths (List[int]): The number of answers each image-question input has.
372
+
373
+ Returns:
374
+ encoder_hidden_states (Tensor): Image-question embeddings after the repetition.
375
+ encoder_hidden_states_m (Tensor): Image-question embeddings from the momentum model after the repetition.
376
+ encoder_attention_mask (Tensor): Question attention mask after the repetition.
377
+ """
378
+ encoder_hidden_states = []
379
+ encoder_attention_mask = []
380
+ for b, n in enumerate(ans_lengths):
381
+ encoder_hidden_states += [multimodal_embeds[b]] * n
382
+ encoder_attention_mask += [question_atts[b]] * n
383
+ encoder_hidden_states = torch.stack(encoder_hidden_states)
384
+ encoder_attention_mask = torch.stack(encoder_attention_mask)
385
+
386
+ with torch.no_grad():
387
+ encoder_hidden_states_m = []
388
+ for b, n in enumerate(ans_lengths):
389
+ encoder_hidden_states_m += [multimodal_embeds_m[b]] * n
390
+ encoder_hidden_states_m = torch.stack(encoder_hidden_states_m)
391
+
392
+ return encoder_hidden_states, encoder_hidden_states_m, encoder_attention_mask
393
+
394
+ def forward(
395
+ self,
396
+ image: Tensor,
397
+ question: Tensor,
398
+ question_atts: Tensor,
399
+ answers: Tensor,
400
+ answers_atts: Tensor,
401
+ ans_weights: Optional[Tensor] = None,
402
+ ans_lengths: Optional[List[int]] = None,
403
+ alpha: Optional[float] = 0.0,
404
+ k: Optional[int] = 128,
405
+ is_train: Optional[bool] = True,
406
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
407
+ if is_train:
408
+ return self._train_forward(
409
+ image,
410
+ question,
411
+ question_atts,
412
+ answers,
413
+ answers_atts,
414
+ ans_weights,
415
+ ans_lengths,
416
+ alpha,
417
+ )
418
+ else:
419
+ return self._eval_forward(
420
+ image,
421
+ question,
422
+ question_atts,
423
+ answers,
424
+ answers_atts,
425
+ k,
426
+ )
427
+
428
+
429
+ class ALBEFModelForRetrieval(nn.Module):
430
+ """
431
+ ALBEF Model for Retrieval finetuning and inference.
432
+ In training mode, the forward step computes image-text contrastive loss and
433
+ image-text matching loss.
434
+ In evaluation mode, the forward step takes 3 types of input:
435
+ image: encode image input, project and normalize the embeddings.
436
+ text: encode text input, project and normalize the embeddings.
437
+ multimodal: create multimodal embeddings from image and text
438
+ embeddings, and compute image-text matching scores.
439
+
440
+ Args:
441
+ model_with_similarity (ALBEFModelWithSimilarity): Instantiated ALBEFModelWithSimilarity.
442
+ itc_loss (ImageTextContrastiveLoss): Instantiated ImageTextContrastiveLoss.
443
+ hidden_size (int): Dimensionality of encoder outputs.
444
+
445
+ Inputs:
446
+ image (Optional[Tensor] of shape (B, C, H, W)): Image features.
447
+ Required if is_train is True.
448
+ Required if input_type is "image" or "multimodal".
449
+ text (Optional[Tensor] of shape (B, L)): Text features.
450
+ Required if is_train is True.
451
+ Required if input_type is "text" or "multimodal".
452
+ text_atts (Tensor of shape (B, L)): Text attention mask.
453
+ Required if is_train is True.
454
+ Required if input_type is "text" or "multimodal".
455
+ idx (Tensor of shape (B)): Identifier for each image sample.
456
+ Required if is_train is True.
457
+ alpha (Optional[float]): The interpolation value for momentum distillation in the image-text contrastive loss.
458
+ Default is 0.
459
+ input_type (Optional[str]): "image", "text", or "multimodal" indicating the encoding type.
460
+ Required if is_train is False.
461
+ is_train (Optional[bool]): Whether the model is in training.
462
+ Default is True.
463
+
464
+ Returns:
465
+ is_train is True:
466
+ Tensor: The sum of itc loss and itm loss.
467
+ is_train is False:
468
+ input_type is "image":
469
+ Tuple[Tensor, Tensor]: Image embeddings and projected image features.
470
+ input_type is "text":
471
+ Tuple[Tensor, Tensor]: Text embeddings and projected text features.
472
+ input_type is "multimodal"
473
+ Tensor: Scores for the retrieval task.
474
+ """
475
+
476
+ def __init__(
477
+ self,
478
+ model_with_similarity: ALBEFModelWithSimilarity,
479
+ itc_loss: ImageTextContrastiveLoss,
480
+ hidden_size: int,
481
+ ) -> None:
482
+ super().__init__()
483
+ self.model_with_similarity = model_with_similarity
484
+ self.itc_loss = itc_loss
485
+ self.itm_head = nn.Linear(hidden_size, 2)
486
+
487
+ def _train_forward(
488
+ self,
489
+ image: Tensor,
490
+ text: Tensor,
491
+ text_atts: Tensor,
492
+ idx: Tensor,
493
+ alpha: float,
494
+ ) -> Tensor:
495
+ encoder_output = self.model_with_similarity(image, text, text_atts, idx)
496
+
497
+ # compute image-text contrastive loss
498
+ similarity_outputs = encoder_output.similarity
499
+ similarity_targets = encoder_output.sim_targets
500
+ itc_loss = self.itc_loss(
501
+ similarity_outputs.sim_i2t,
502
+ similarity_outputs.sim_t2i,
503
+ similarity_outputs.sim_i2t_m,
504
+ similarity_outputs.sim_t2i_m,
505
+ similarity_targets,
506
+ alpha,
507
+ )
508
+
509
+ # compute image-text matching loss
510
+ pos_embeddings = encoder_output.multimodal_embeddings[:, 0, :]
511
+ neg_embeddings = encoder_output.multimodal_embeddings_neg[:, 0, :]
512
+ vl_embeddings = torch.cat([pos_embeddings, neg_embeddings], dim=0)
513
+ vl_output = self.itm_head(vl_embeddings)
514
+ itm_labels = torch.cat(
515
+ [
516
+ torch.ones(pos_embeddings.size(0), dtype=torch.long),
517
+ torch.zeros(neg_embeddings.size(0), dtype=torch.long),
518
+ ],
519
+ dim=0,
520
+ ).to(vl_embeddings.device)
521
+ itm_loss = F.cross_entropy(vl_output, itm_labels)
522
+
523
+ loss = itc_loss + itm_loss
524
+ return loss
525
+
526
+ def _encode_image(
527
+ self,
528
+ image: Tensor,
529
+ ) -> Tuple[Tensor, Tensor]:
530
+ image_embed = self.model_with_similarity.albef_model.vision_encoder(image)
531
+ image_feat = F.normalize(
532
+ self.model_with_similarity.vision_proj(image_embed[:, 0, :]), dim=-1
533
+ )
534
+ return image_embed, image_feat
535
+
536
+ def _encode_text(
537
+ self,
538
+ text: Tensor,
539
+ text_atts: Tensor,
540
+ ) -> Tuple[Tensor, Tensor]:
541
+ text_embed = self.model_with_similarity.albef_model.text_encoder(
542
+ text, text_atts
543
+ ).last_hidden_state
544
+ text_feat = F.normalize(
545
+ self.model_with_similarity.text_proj(text_embed[:, 0, :]), dim=-1
546
+ )
547
+ return text_embed, text_feat
548
+
549
+ def _image_text_matching_score(
550
+ self,
551
+ image: Tensor,
552
+ text: Tensor,
553
+ text_atts: Tensor,
554
+ ) -> Tensor:
555
+ multimodal_embeds = self.model_with_similarity.albef_model.multimodal_encoder(
556
+ text,
557
+ text_atts,
558
+ image,
559
+ )
560
+ score = self.itm_head(multimodal_embeds[:, 0, :])[:, 1]
561
+ return score
562
+
563
+ def _eval_forward(
564
+ self,
565
+ input_type: str,
566
+ image: Optional[Tensor],
567
+ text: Optional[Tensor],
568
+ text_atts: Optional[Tensor],
569
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
570
+ if input_type == "image":
571
+ assert image is not None, "image input tensor cannot be None"
572
+ return self._encode_image(image)
573
+
574
+ elif input_type == "text":
575
+ assert (
576
+ text is not None and text_atts is not None
577
+ ), "text and text attention mask cannot be None"
578
+ return self._encode_text(text, text_atts)
579
+
580
+ elif input_type == "multimodal":
581
+ assert (
582
+ image is not None and text is not None and text_atts is not None
583
+ ), "image embeddings, text embeddings, and text attention mask cannot be None"
584
+ return self._image_text_matching_score(image, text, text_atts)
585
+
586
+ else:
587
+ raise ValueError("input_type must be image, text, or multimodal")
588
+
589
+ def forward(
590
+ self,
591
+ image: Optional[Tensor] = None,
592
+ text: Optional[Tensor] = None,
593
+ text_atts: Optional[Tensor] = None,
594
+ idx: Optional[Tensor] = None,
595
+ alpha: Optional[float] = 0.0,
596
+ input_type: Optional[str] = None,
597
+ is_train: Optional[bool] = True,
598
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
599
+ if is_train:
600
+ return self._train_forward(
601
+ image,
602
+ text,
603
+ text_atts,
604
+ idx,
605
+ alpha,
606
+ )
607
+ else:
608
+ return self._eval_forward(
609
+ input_type,
610
+ image,
611
+ text,
612
+ text_atts,
613
+ )
614
+
615
+
616
+ def albef_model_for_vqa(
617
+ config: Dict[str, Any], pretrained: bool = False
618
+ ) -> ALBEFModelForVQA:
619
+ vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
620
+ text_encoder = bert_text_encoder(**config["text_encoder_args"])
621
+ question_multimodal_encoder = ALBEFMultimodalEncoder(
622
+ **config["multimodal_encoder_args"]
623
+ )
624
+ text_embeddings = BERTTextEmbeddings(**config["text_embeddings_args"])
625
+ answer_multimodal_encoder = ALBEFMultimodalEncoder(
626
+ **config["multimodal_encoder_args"]
627
+ )
628
+ prediction_head = PredictionHead(**config["prediction_head_args"])
629
+ albef_model = ALBEFModel(vision_encoder, text_encoder, question_multimodal_encoder)
630
+ decoder = ALBEFDecoder(text_embeddings, answer_multimodal_encoder, prediction_head)
631
+ loss = CausalLanguageModelingLoss()
632
+ model = ALBEFModelForVQA(albef_model, decoder, loss)
633
+
634
+ if pretrained:
635
+ checkpoint = torch.hub.load_state_dict_from_url(
636
+ _ALBEF_PRETRAINED_URLS["vqa"], map_location="cpu"
637
+ )
638
+ model.load_state_dict(checkpoint)
639
+ return model
640
+
641
+
642
+ def albef_model_for_retrieval(
643
+ config: Dict[str, Any], pretrained: bool = False
644
+ ) -> ALBEFModelForRetrieval:
645
+ vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
646
+ text_encoder = bert_text_encoder(**config["text_encoder_args"])
647
+ multimodal_encoder = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
648
+ vision_proj = nn.Linear(**config["projection_args"])
649
+ text_proj = nn.Linear(**config["projection_args"])
650
+
651
+ albef_model = ALBEFModel(vision_encoder, text_encoder, multimodal_encoder)
652
+ albef_model_with_sim = ALBEFModelWithSimilarity(
653
+ albef_model, vision_proj, text_proj, **config["similarity_args"]
654
+ )
655
+ itc_loss = ImageTextContrastiveLoss()
656
+
657
+ model = ALBEFModelForRetrieval(
658
+ albef_model_with_sim, itc_loss, config["hidden_size"]
659
+ )
660
+
661
+ if pretrained:
662
+ checkpoint = torch.hub.load_state_dict_from_url(
663
+ _ALBEF_PRETRAINED_URLS["retrieval"], map_location="cpu"
664
+ )
665
+ model.load_state_dict(checkpoint)
666
+ return model
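
For illustration, a standalone sketch of the combined causal and padding mask built by ALBEFDecoder.get_extended_attention_mask_for_decoder above, using torch.tril in place of torchmultimodal's get_causal_attention_mask helper; the toy padding mask is an assumption.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])           # one sequence, last token is padding
batch_size, seq_len = attention_mask.shape
causal = torch.tril(torch.ones(seq_len, seq_len))        # position i may attend to j <= i
causal = causal.repeat(batch_size, 1).view(batch_size, seq_len, seq_len)
extended = causal[:, None, :, :] * attention_mask[:, None, None, :]   # (B, 1, L, L)
print(extended[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 0.]])
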
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ opencv-python==4.6.0.66
2
+ pytorch-lightning==1.6.0
3
+ Pillow==9.0.1
4
+ ruamel_yaml==0.17.21
5
+ transformers==4.24.0
utils.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
+
+ import json
+ import os
+
+ import torch
+ import torch.distributed as dist
+ from torch import nn
+
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in the master process
+     """
+     import builtins as __builtin__
+
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop("force", False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def init_distributed_mode(args):
+     if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+         args["rank"] = int(os.environ["RANK"])
+         args["world_size"] = int(os.environ["WORLD_SIZE"])
+         args["gpu"] = int(os.environ["LOCAL_RANK"])
+     elif "SLURM_PROCID" in os.environ:
+         args["rank"] = int(os.environ["SLURM_PROCID"])
+         args["gpu"] = args["rank"] % torch.cuda.device_count()
+     else:
+         print("Not using distributed mode")
+         args["distributed"] = False
+         return
+
+     args["distributed"] = True
+
+     torch.cuda.set_device(args["gpu"])
+     args["dist_backend"] = "nccl"
+     print(
+         "| distributed init (rank {}): {}".format(args["rank"], args["dist_url"]),
+         flush=True,
+     )
+     torch.distributed.init_process_group(
+         backend=args["dist_backend"],
+         init_method=args["dist_url"],
+         world_size=args["world_size"],
+         rank=args["rank"],
+     )
+     torch.distributed.barrier()
+     setup_for_distributed(args["rank"] == 0)
+
+
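A short sketch of how init_distributed_mode is typically driven; the args dict literal below is an assumption, and only the key names come from the function above.

# Sketch only: assumes a torchrun-style launcher that sets RANK, WORLD_SIZE,
# and LOCAL_RANK, and a machine with at least one CUDA device.
args = {"dist_url": "env://"}
init_distributed_mode(args)
# With such a launcher, args is updated in place with "rank", "world_size",
# "gpu", "dist_backend", and "distributed" = True; without one, the function
# prints a notice and sets "distributed" = False.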
+ def save_result(result, directory, file_name):
+     rank_path = os.path.join(directory, "{}_rank_{}.json".format(file_name, get_rank()))
+     main_path = os.path.join(directory, "{}.json".format(file_name))
+     json.dump(result, open(rank_path, "w"))
+
+     if is_dist_avail_and_initialized():
+         dist.barrier()
+
+     if is_main_process():
+         result = []
+         for rank in range(get_world_size()):
+             rank_path = os.path.join(
+                 directory, "{}_rank_{}.json".format(file_name, rank)
+             )
+             rank_res = json.load(open(rank_path, "r"))
+             result += rank_res
+         json.dump(result, open(main_path, "w"))
+
+     if is_dist_avail_and_initialized():
+         dist.barrier()
+
+
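A hedged usage sketch for save_result: every rank writes its own shard, and the main process merges them. The payload and directory below are illustrative assumptions.

# Sketch only: `predictions` and the output directory are assumed to exist.
predictions = [{"question_id": 0, "answer": "no"}]  # per-rank results, illustrative
save_result(predictions, "output", "vqa_result")
# Writes output/vqa_result_rank_<r>.json per process and, on the main process,
# the merged output/vqa_result.json.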
+ def add_weight_decay(model: nn.Module, weight_decay: float) -> list:
+     decay = []
+     no_decay = []
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue  # skip weight_decay for momentum models
+         if len(param.shape) == 1 or name.endswith(".bias"):
+             no_decay.append(param)
+         else:
+             decay.append(param)
+     return [
+         {"params": no_decay, "weight_decay": 0.0},
+         {"params": decay, "weight_decay": weight_decay},
+     ]
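Because add_weight_decay returns optimizer parameter groups, it plugs directly into a torch optimizer. A brief sketch follows; the model variable and hyperparameter values are assumptions.

# Sketch only: `model` is any nn.Module; lr and weight_decay are illustrative.
param_groups = add_weight_decay(model, weight_decay=0.02)
optimizer = torch.optim.AdamW(param_groups, lr=2e-5)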
vqa_data.json ADDED
@@ -0,0 +1 @@
+ [{"image": "images/COCO_val2014_000000184359.jpg", "question": "Is this a train station?", "answers": ["no", "no", "no", "no", "no", "no", "no", "no", "no", "no"]}, {"image": "images/COCO_val2014_000000407072.jpg", "question": "Was this photo taken at night?", "answers": ["yes", "yes", "yes", "yes", "yes", "yes", "yes", "yes", "yes", "yes"]}, {"image": "images/COCO_val2014_000000111207.jpg", "question": "How many photos in one?", "answers": ["2", "2", "2", "2", "2", "2", "2", "2", "2", "2"]}, {"image": "images/COCO_val2014_000000057222.jpg", "question": "How many bears are there?", "answers": ["2", "3", "3", "4", "2", "2", "3", "3", "2", "3"]}, {"image": "images/COCO_val2014_000000159269.jpg", "question": "What time of the day it is?", "answers": ["evening", "evening", "dusk", "sunset", "sunset", "dusk", "morning", "dusk", "evening", "4 pm"]}, {"image": "images/COCO_val2014_000000026348.jpg", "question": "What color is the refrigerator handle?", "answers": ["white", "white", "white", "white", "white", "white", "white", "white", "white", "white"]}, {"image": "images/COCO_val2014_000000473994.jpg", "question": "What does this animal eat?", "answers": ["meat", "dog food", "dog food", "dog food", "dog food", "dog food", "frisbee", "dog food", "frisbee", "dog food"]}, {"image": "images/COCO_val2014_000000552075.jpg", "question": "Who is wearing a hat?", "answers": ["no one", "woman", "no one", "nobody", "no one", "nobody", "no", "nobody", "nobody", "man"]}]
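Each record in vqa_data.json pairs a COCO val2014 image path with a question and its ten annotator answers. A small reading sketch follows; it assumes the referenced images are available locally under the listed paths.

# Sketch only: assumes the images/ directory referenced by this file is present.
import json
from PIL import Image

samples = json.load(open("vqa_data.json"))
for sample in samples:
    image = Image.open(sample["image"]).convert("RGB")
    question, answers = sample["question"], sample["answers"]  # ten reference answers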