{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from util import UIDataset, Vocabulary\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from model import *\n", "from torchvision import transforms\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "dataset = UIDataset('./dataset/training', 'voc.pkl')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "net = Pix2Code().cuda()\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(net.parameters(), lr = 0.0001)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 2.9524800777435303, Epoch: 0\n", "Loss: 2.8884081840515137, Epoch: 0\n", "Loss: 2.829449415206909, Epoch: 0\n", "Loss: 2.740262269973755, Epoch: 0\n", "Loss: 2.621814489364624, Epoch: 0\n", "Loss: 2.5236337184906006, Epoch: 0\n", "Loss: 2.4583466053009033, Epoch: 0\n", "Loss: 2.4608116149902344, Epoch: 0\n", "Loss: 2.4273080825805664, Epoch: 0\n", "Loss: 2.385077714920044, Epoch: 0\n", "Loss: 2.3915295600891113, Epoch: 0\n", "Loss: 2.3580784797668457, Epoch: 0\n", "Loss: 2.353372812271118, Epoch: 0\n", "Loss: 2.3451170921325684, Epoch: 0\n", "Loss: 2.350393772125244, Epoch: 0\n", "Loss: 2.3572347164154053, Epoch: 0\n", "Loss: 2.3479745388031006, Epoch: 0\n", "Loss: 2.3289999961853027, Epoch: 0\n", "Loss: 2.3202321529388428, Epoch: 0\n", "Loss: 2.318603277206421, Epoch: 0\n", "Loss: 2.3135170936584473, Epoch: 0\n", "Loss: 2.2949447631835938, Epoch: 0\n", "Loss: 2.3082993030548096, Epoch: 0\n", "Loss: 2.3046371936798096, Epoch: 0\n", "Loss: 2.306605339050293, Epoch: 0\n", "Loss: 2.303187847137451, 
Epoch: 0\n", "Loss: 2.2787742614746094, Epoch: 0\n", "Loss: 2.300393581390381, Epoch: 0\n", "Loss: 2.2998340129852295, Epoch: 0\n", "Loss: 2.2719686031341553, Epoch: 0\n", "Loss: 2.2761404514312744, Epoch: 0\n", "Loss: 2.3393924236297607, Epoch: 0\n", "Loss: 2.2692413330078125, Epoch: 0\n", "Loss: 2.270902395248413, Epoch: 0\n", "Loss: 2.282170534133911, Epoch: 0\n", "Loss: 2.258744955062866, Epoch: 0\n", "Loss: 2.261420965194702, Epoch: 0\n", "Loss: 2.247162103652954, Epoch: 0\n", "Loss: 2.238236427307129, Epoch: 0\n", "Loss: 2.2235021591186523, Epoch: 0\n", "Loss: 2.2508511543273926, Epoch: 0\n", "Loss: 2.2157509326934814, Epoch: 0\n", "Loss: 2.2379965782165527, Epoch: 0\n", "Loss: 2.231006145477295, Epoch: 0\n", "Loss: 2.2043988704681396, Epoch: 0\n", "Loss: 2.2038233280181885, Epoch: 0\n", "Loss: 2.1575405597686768, Epoch: 0\n", "Loss: 2.1855974197387695, Epoch: 0\n", "Loss: 2.156130075454712, Epoch: 0\n", "Loss: 2.144746780395508, Epoch: 0\n", "Loss: 2.1328115463256836, Epoch: 0\n", "Loss: 2.1185996532440186, Epoch: 0\n", "Loss: 2.1840450763702393, Epoch: 0\n", "Loss: 2.0942633152008057, Epoch: 0\n", "Loss: 2.0785911083221436, Epoch: 0\n", "Loss: 2.0678865909576416, Epoch: 0\n", "Loss: 2.010136604309082, Epoch: 0\n", "Loss: 1.9797004461288452, Epoch: 0\n", "Loss: 1.9312745332717896, Epoch: 0\n", "Loss: 1.900559663772583, Epoch: 0\n", "Loss: 1.8531383275985718, Epoch: 0\n", "Loss: 1.9304338693618774, Epoch: 0\n", "Loss: 1.7415088415145874, Epoch: 0\n", "Loss: 1.6876497268676758, Epoch: 0\n", "Loss: 1.7912997007369995, Epoch: 0\n", "Loss: 1.5977957248687744, Epoch: 0\n", "Loss: 1.4403373003005981, Epoch: 0\n", "Loss: 1.4909701347351074, Epoch: 0\n", "Loss: 1.4378795623779297, Epoch: 0\n", "Loss: 1.3532891273498535, Epoch: 0\n", "Loss: 1.5037901401519775, Epoch: 0\n", "Loss: 1.1528806686401367, Epoch: 0\n", "Loss: 1.1900337934494019, Epoch: 0\n", "Loss: 1.2141656875610352, Epoch: 0\n", "Loss: 1.0680357217788696, Epoch: 0\n", "Loss: 1.336559534072876, Epoch: 0\n", 
"Loss: 1.1178141832351685, Epoch: 0\n", "Loss: 1.019561529159546, Epoch: 0\n", "Loss: 0.9516667723655701, Epoch: 0\n", "Loss: 0.9714117050170898, Epoch: 0\n", "Loss: 0.8897005319595337, Epoch: 0\n", "Loss: 1.470299482345581, Epoch: 0\n", "Loss: 0.8343897461891174, Epoch: 0\n", "Loss: 1.0053107738494873, Epoch: 0\n", "Loss: 0.8614187240600586, Epoch: 0\n", "Loss: 0.8838141560554504, Epoch: 0\n", "Loss: 0.7585976719856262, Epoch: 0\n", "Loss: 0.736517071723938, Epoch: 0\n", "Loss: 0.741409420967102, Epoch: 0\n", "Loss: 0.8206121325492859, Epoch: 0\n", "Loss: 0.7003164291381836, Epoch: 0\n", "Loss: 0.6972914934158325, Epoch: 0\n", "Loss: 0.653584361076355, Epoch: 0\n", "Loss: 0.6935228705406189, Epoch: 0\n", "Loss: 0.9742002487182617, Epoch: 0\n", "Loss: 0.6670038104057312, Epoch: 0\n", "Loss: 0.6938544511795044, Epoch: 0\n", "Loss: 0.6238456964492798, Epoch: 0\n", "Loss: 0.5920361280441284, Epoch: 0\n", "Loss: 0.6854602694511414, Epoch: 0\n", "Loss: 0.6296735405921936, Epoch: 0\n", "Loss: 0.6589140295982361, Epoch: 0\n", "Loss: 0.5916438698768616, Epoch: 0\n", "Loss: 0.870532214641571, Epoch: 0\n", "Loss: 0.5000392198562622, Epoch: 0\n", "Loss: 0.5296663641929626, Epoch: 0\n", "Loss: 0.731133222579956, Epoch: 0\n", "Loss: 0.5028390884399414, Epoch: 0\n", "Loss: 0.5200638175010681, Epoch: 0\n", "Loss: 0.4418269097805023, Epoch: 0\n", "Loss: 0.4933643937110901, Epoch: 0\n", "Loss: 0.5051903128623962, Epoch: 0\n", "Loss: 0.5060564279556274, Epoch: 0\n", "Loss: 0.46932217478752136, Epoch: 0\n", "Loss: 0.6307882070541382, Epoch: 0\n", "Loss: 0.4105120301246643, Epoch: 0\n", "Loss: 0.45868533849716187, Epoch: 0\n", "Loss: 0.4584932029247284, Epoch: 0\n", "Loss: 0.650614857673645, Epoch: 0\n", "Loss: 0.4539167582988739, Epoch: 0\n", "Loss: 0.4140841066837311, Epoch: 0\n", "Loss: 0.4211380183696747, Epoch: 0\n", "Loss: 0.4530402719974518, Epoch: 0\n", "Loss: 0.40992099046707153, Epoch: 0\n", "Loss: 0.45029035210609436, Epoch: 0\n", "Loss: 0.39612260460853577, Epoch: 0\n", 
"Loss: 0.43665429949760437, Epoch: 0\n", "Loss: 0.3842218220233917, Epoch: 0\n", "Loss: 0.38374099135398865, Epoch: 0\n", "Loss: 0.4136958122253418, Epoch: 0\n", "Loss: 0.44654661417007446, Epoch: 0\n", "Loss: 0.3930183947086334, Epoch: 0\n", "Loss: 0.40135490894317627, Epoch: 0\n", "Loss: 0.34777334332466125, Epoch: 0\n", "Loss: 0.4132601320743561, Epoch: 0\n", "Loss: 0.4124941825866699, Epoch: 0\n", "Loss: 0.42920032143592834, Epoch: 0\n", "Loss: 0.37013471126556396, Epoch: 0\n", "Loss: 0.3565217852592468, Epoch: 0\n", "Loss: 0.35290029644966125, Epoch: 0\n", "Loss: 0.3940131366252899, Epoch: 0\n", "Loss: 0.3507075309753418, Epoch: 0\n", "Loss: 0.340397447347641, Epoch: 0\n", "Loss: 0.34468895196914673, Epoch: 0\n", "Loss: 0.35861828923225403, Epoch: 0\n", "Loss: 0.331664502620697, Epoch: 0\n", "Loss: 0.3724385201931, Epoch: 0\n", "Loss: 0.4620945453643799, Epoch: 0\n", "Loss: 0.40686291456222534, Epoch: 0\n", "Loss: 0.34405651688575745, Epoch: 0\n", "Loss: 0.33751004934310913, Epoch: 0\n", "Loss: 0.33067846298217773, Epoch: 0\n", "Loss: 0.35637813806533813, Epoch: 0\n", "Loss: 0.33678290247917175, Epoch: 0\n", "Loss: 0.3399815857410431, Epoch: 0\n", "Loss: 0.3555727005004883, Epoch: 0\n", "Loss: 0.36221396923065186, Epoch: 0\n", "Loss: 0.31564000248908997, Epoch: 0\n", "Loss: 0.36513498425483704, Epoch: 1\n", "Loss: 0.38518428802490234, Epoch: 1\n", "Loss: 0.39113444089889526, Epoch: 1\n", "Loss: 0.31585872173309326, Epoch: 1\n", "Loss: 0.3462843894958496, Epoch: 1\n", "Loss: 0.31152087450027466, Epoch: 1\n", "Loss: 0.49244046211242676, Epoch: 1\n", "Loss: 0.3516445755958557, Epoch: 1\n", "Loss: 0.31429073214530945, Epoch: 1\n", "Loss: 0.3202592730522156, Epoch: 1\n", "Loss: 0.3343992531299591, Epoch: 1\n", "Loss: 0.3129790127277374, Epoch: 1\n", "Loss: 0.3035936653614044, Epoch: 1\n", "Loss: 0.3117291331291199, Epoch: 1\n", "Loss: 0.34716925024986267, Epoch: 1\n", "Loss: 0.2994453012943268, Epoch: 1\n", "Loss: 0.3337497413158417, Epoch: 1\n", "Loss: 
0.29858383536338806, Epoch: 1\n", "Loss: 0.3567463457584381, Epoch: 1\n", "Loss: 0.2959181070327759, Epoch: 1\n", "Loss: 0.2913894057273865, Epoch: 1\n", "Loss: 0.31215083599090576, Epoch: 1\n", "Loss: 0.28617092967033386, Epoch: 1\n", "Loss: 0.31746673583984375, Epoch: 1\n", "Loss: 0.2842235565185547, Epoch: 1\n", "Loss: 0.3054017126560211, Epoch: 1\n", "Loss: 0.2990272045135498, Epoch: 1\n", "Loss: 0.2779654264450073, Epoch: 1\n", "Loss: 0.30389276146888733, Epoch: 1\n", "Loss: 0.3101218044757843, Epoch: 1\n", "Loss: 0.28804558515548706, Epoch: 1\n", "Loss: 0.34255534410476685, Epoch: 1\n", "Loss: 0.27916717529296875, Epoch: 1\n", "Loss: 0.2679106891155243, Epoch: 1\n", "Loss: 0.32168295979499817, Epoch: 1\n", "Loss: 0.2777051627635956, Epoch: 1\n", "Loss: 0.2755148708820343, Epoch: 1\n", "Loss: 0.2674615681171417, Epoch: 1\n", "Loss: 0.2992023825645447, Epoch: 1\n", "Loss: 0.2748924493789673, Epoch: 1\n", "Loss: 0.29646700620651245, Epoch: 1\n", "Loss: 0.28040811419487, Epoch: 1\n", "Loss: 0.30196037888526917, Epoch: 1\n", "Loss: 0.2965075969696045, Epoch: 1\n", "Loss: 0.26315343379974365, Epoch: 1\n", "Loss: 0.251770555973053, Epoch: 1\n", "Loss: 0.26690012216567993, Epoch: 1\n", "Loss: 0.2505774199962616, Epoch: 1\n", "Loss: 0.28790804743766785, Epoch: 1\n", "Loss: 0.2832382023334503, Epoch: 1\n", "Loss: 0.2817041575908661, Epoch: 1\n", "Loss: 0.28491726517677307, Epoch: 1\n", "Loss: 0.301067054271698, Epoch: 1\n", "Loss: 0.263874888420105, Epoch: 1\n", "Loss: 0.24870161712169647, Epoch: 1\n", "Loss: 0.2826177775859833, Epoch: 1\n", "Loss: 0.2598406672477722, Epoch: 1\n", "Loss: 0.25691503286361694, Epoch: 1\n", "Loss: 0.2716769576072693, Epoch: 1\n", "Loss: 0.2626866400241852, Epoch: 1\n", "Loss: 0.28078770637512207, Epoch: 1\n", "Loss: 0.36835792660713196, Epoch: 1\n", "Loss: 0.2390429675579071, Epoch: 1\n", "Loss: 0.2525367736816406, Epoch: 1\n", "Loss: 0.2904892861843109, Epoch: 1\n", "Loss: 0.2849578261375427, Epoch: 1\n", "Loss: 0.2514594793319702, 
Epoch: 1\n", "Loss: 0.2684769630432129, Epoch: 1\n", "Loss: 0.24913877248764038, Epoch: 1\n", "Loss: 0.26298603415489197, Epoch: 1\n", "Loss: 0.2606211006641388, Epoch: 1\n", "Loss: 0.24511288106441498, Epoch: 1\n", "Loss: 0.25391775369644165, Epoch: 1\n", "Loss: 0.28190696239471436, Epoch: 1\n", "Loss: 0.2384095937013626, Epoch: 1\n", "Loss: 0.25621259212493896, Epoch: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.23397196829319, Epoch: 1\n", "Loss: 0.2452157586812973, Epoch: 1\n", "Loss: 0.25013604760169983, Epoch: 1\n", "Loss: 0.24252459406852722, Epoch: 1\n", "Loss: 0.24472461640834808, Epoch: 1\n", "Loss: 0.3319304883480072, Epoch: 1\n", "Loss: 0.23765243589878082, Epoch: 1\n", "Loss: 0.24847112596035004, Epoch: 1\n", "Loss: 0.2331979125738144, Epoch: 1\n", "Loss: 0.26821067929267883, Epoch: 1\n", "Loss: 0.23324160277843475, Epoch: 1\n", "Loss: 0.2333941012620926, Epoch: 1\n", "Loss: 0.23415108025074005, Epoch: 1\n", "Loss: 0.2484195977449417, Epoch: 1\n", "Loss: 0.23526932299137115, Epoch: 1\n", "Loss: 0.22446300089359283, Epoch: 1\n", "Loss: 0.22437800467014313, Epoch: 1\n", "Loss: 0.2274080067873001, Epoch: 1\n", "Loss: 0.23755177855491638, Epoch: 1\n", "Loss: 0.21210730075836182, Epoch: 1\n", "Loss: 0.2223813533782959, Epoch: 1\n", "Loss: 0.22736792266368866, Epoch: 1\n", "Loss: 0.22397784888744354, Epoch: 1\n", "Loss: 0.2413712739944458, Epoch: 1\n", "Loss: 0.2235811948776245, Epoch: 1\n", "Loss: 0.24524028599262238, Epoch: 1\n", "Loss: 0.22623442113399506, Epoch: 1\n", "Loss: 0.24114780128002167, Epoch: 1\n", "Loss: 0.22427205741405487, Epoch: 1\n", "Loss: 0.2151176482439041, Epoch: 1\n", "Loss: 0.21807488799095154, Epoch: 1\n", "Loss: 0.2142992466688156, Epoch: 1\n", "Loss: 0.21087424457073212, Epoch: 1\n", "Loss: 0.2140428125858307, Epoch: 1\n", "Loss: 0.21240393817424774, Epoch: 1\n", "Loss: 0.20539182424545288, Epoch: 1\n", "Loss: 0.20411120355129242, Epoch: 1\n", "Loss: 0.21139389276504517, Epoch: 1\n", "Loss: 
0.24850775301456451, Epoch: 1\n", "Loss: 0.21610818803310394, Epoch: 1\n", "Loss: 0.20004600286483765, Epoch: 1\n", "Loss: 0.19582688808441162, Epoch: 1\n", "Loss: 0.23555868864059448, Epoch: 1\n", "Loss: 0.20441941916942596, Epoch: 1\n", "Loss: 0.20177888870239258, Epoch: 1\n", "Loss: 0.19751212000846863, Epoch: 1\n", "Loss: 0.20864750444889069, Epoch: 1\n", "Loss: 0.20159509778022766, Epoch: 1\n", "Loss: 0.2005327045917511, Epoch: 1\n", "Loss: 0.1999230831861496, Epoch: 1\n", "Loss: 0.2040192186832428, Epoch: 1\n", "Loss: 0.19645045697689056, Epoch: 1\n", "Loss: 0.199203759431839, Epoch: 1\n", "Loss: 0.20018835365772247, Epoch: 1\n", "Loss: 0.19527064263820648, Epoch: 1\n", "Loss: 0.19556036591529846, Epoch: 1\n", "Loss: 0.19404765963554382, Epoch: 1\n", "Loss: 0.18973395228385925, Epoch: 1\n", "Loss: 0.19332443177700043, Epoch: 1\n", "Loss: 0.1943163424730301, Epoch: 1\n", "Loss: 0.22645233571529388, Epoch: 1\n", "Loss: 0.18985994160175323, Epoch: 1\n", "Loss: 0.18457040190696716, Epoch: 1\n", "Loss: 0.1887940615415573, Epoch: 1\n", "Loss: 0.184039905667305, Epoch: 1\n", "Loss: 0.18834181129932404, Epoch: 1\n", "Loss: 0.19446147978305817, Epoch: 1\n", "Loss: 0.1980251669883728, Epoch: 1\n", "Loss: 0.17985910177230835, Epoch: 1\n", "Loss: 0.1837856024503708, Epoch: 1\n", "Loss: 0.187068909406662, Epoch: 1\n", "Loss: 0.21451883018016815, Epoch: 1\n", "Loss: 0.18576252460479736, Epoch: 1\n", "Loss: 0.1763405203819275, Epoch: 1\n", "Loss: 0.1780242919921875, Epoch: 1\n", "Loss: 0.17861078679561615, Epoch: 1\n", "Loss: 0.1907288134098053, Epoch: 1\n", "Loss: 0.1792900711297989, Epoch: 1\n", "Loss: 0.18664227426052094, Epoch: 1\n", "Loss: 0.17793485522270203, Epoch: 1\n", "Loss: 0.17487949132919312, Epoch: 1\n", "Loss: 0.20927435159683228, Epoch: 1\n", "Loss: 0.17920136451721191, Epoch: 2\n", "Loss: 0.17775531113147736, Epoch: 2\n", "Loss: 0.17145593464374542, Epoch: 2\n", "Loss: 0.17465360462665558, Epoch: 2\n", "Loss: 0.18824218213558197, Epoch: 2\n", "Loss: 
0.17750504612922668, Epoch: 2\n", "Loss: 0.19186504185199738, Epoch: 2\n", "Loss: 0.16394080221652985, Epoch: 2\n", "Loss: 0.1819714903831482, Epoch: 2\n", "Loss: 0.16088007390499115, Epoch: 2\n", "Loss: 0.18525062501430511, Epoch: 2\n", "Loss: 0.17226216197013855, Epoch: 2\n", "Loss: 0.17414595186710358, Epoch: 2\n", "Loss: 0.18195168673992157, Epoch: 2\n", "Loss: 0.17499519884586334, Epoch: 2\n", "Loss: 0.17727644741535187, Epoch: 2\n", "Loss: 0.17118574678897858, Epoch: 2\n", "Loss: 0.17873334884643555, Epoch: 2\n", "Loss: 0.1621459275484085, Epoch: 2\n", "Loss: 0.17832708358764648, Epoch: 2\n", "Loss: 0.16719700396060944, Epoch: 2\n", "Loss: 0.17284554243087769, Epoch: 2\n", "Loss: 0.16878937184810638, Epoch: 2\n", "Loss: 0.16516879200935364, Epoch: 2\n", "Loss: 0.16708654165267944, Epoch: 2\n", "Loss: 0.1831068992614746, Epoch: 2\n", "Loss: 0.18170510232448578, Epoch: 2\n", "Loss: 0.1666964292526245, Epoch: 2\n", "Loss: 0.18964578211307526, Epoch: 2\n", "Loss: 0.16580110788345337, Epoch: 2\n", "Loss: 0.1696055829524994, Epoch: 2\n", "Loss: 0.18262037634849548, Epoch: 2\n", "Loss: 0.16144578158855438, Epoch: 2\n", "Loss: 0.16571661829948425, Epoch: 2\n", "Loss: 0.16096633672714233, Epoch: 2\n", "Loss: 0.16165024042129517, Epoch: 2\n", "Loss: 0.16471704840660095, Epoch: 2\n", "Loss: 0.16848206520080566, Epoch: 2\n", "Loss: 0.17229224741458893, Epoch: 2\n", "Loss: 0.1488664597272873, Epoch: 2\n", "Loss: 0.16164709627628326, Epoch: 2\n", "Loss: 0.1660623997449875, Epoch: 2\n", "Loss: 0.16126243770122528, Epoch: 2\n", "Loss: 0.16471469402313232, Epoch: 2\n", "Loss: 0.15973404049873352, Epoch: 2\n", "Loss: 0.15950854122638702, Epoch: 2\n", "Loss: 0.14817242324352264, Epoch: 2\n", "Loss: 0.15562114119529724, Epoch: 2\n", "Loss: 0.15278998017311096, Epoch: 2\n", "Loss: 0.15652954578399658, Epoch: 2\n", "Loss: 0.1548881232738495, Epoch: 2\n", "Loss: 0.1577402502298355, Epoch: 2\n", "Loss: 0.1587544083595276, Epoch: 2\n", "Loss: 0.16029223799705505, Epoch: 2\n", "Loss: 
0.15766172111034393, Epoch: 2\n", "Loss: 0.1566898226737976, Epoch: 2\n", "Loss: 0.15709587931632996, Epoch: 2\n", "Loss: 0.15407609939575195, Epoch: 2\n", "Loss: 0.15885479748249054, Epoch: 2\n", "Loss: 0.16054479777812958, Epoch: 2\n", "Loss: 0.16161851584911346, Epoch: 2\n", "Loss: 0.19581153988838196, Epoch: 2\n", "Loss: 0.1506919115781784, Epoch: 2\n", "Loss: 0.15983256697654724, Epoch: 2\n", "Loss: 0.1840815395116806, Epoch: 2\n", "Loss: 0.15997251868247986, Epoch: 2\n", "Loss: 0.158875972032547, Epoch: 2\n", "Loss: 0.1611945927143097, Epoch: 2\n", "Loss: 0.15246137976646423, Epoch: 2\n", "Loss: 0.1503104567527771, Epoch: 2\n", "Loss: 0.1955866515636444, Epoch: 2\n", "Loss: 0.14065603911876678, Epoch: 2\n", "Loss: 0.16270065307617188, Epoch: 2\n", "Loss: 0.1538032591342926, Epoch: 2\n", "Loss: 0.15558446943759918, Epoch: 2\n", "Loss: 0.177334263920784, Epoch: 2\n", "Loss: 0.15660960972309113, Epoch: 2\n", "Loss: 0.14623355865478516, Epoch: 2\n", "Loss: 0.1487160623073578, Epoch: 2\n", "Loss: 0.14769864082336426, Epoch: 2\n", "Loss: 0.15187585353851318, Epoch: 2\n", "Loss: 0.18615028262138367, Epoch: 2\n", "Loss: 0.15663014352321625, Epoch: 2\n", "Loss: 0.1567113846540451, Epoch: 2\n", "Loss: 0.1475883573293686, Epoch: 2\n", "Loss: 0.14988942444324493, Epoch: 2\n", "Loss: 0.1529097557067871, Epoch: 2\n", "Loss: 0.14954881370067596, Epoch: 2\n", "Loss: 0.14585697650909424, Epoch: 2\n", "Loss: 0.1537933349609375, Epoch: 2\n", "Loss: 0.16131213307380676, Epoch: 2\n", "Loss: 0.14611506462097168, Epoch: 2\n", "Loss: 0.15565554797649384, Epoch: 2\n", "Loss: 0.14928916096687317, Epoch: 2\n", "Loss: 0.18597358465194702, Epoch: 2\n", "Loss: 0.1498081088066101, Epoch: 2\n", "Loss: 0.1543579250574112, Epoch: 2\n", "Loss: 0.1561511605978012, Epoch: 2\n", "Loss: 0.15706948935985565, Epoch: 2\n", "Loss: 0.16255070269107819, Epoch: 2\n", "Loss: 0.15505479276180267, Epoch: 2\n", "Loss: 0.15793101489543915, Epoch: 2\n", "Loss: 0.15751025080680847, Epoch: 2\n", "Loss: 
0.17496031522750854, Epoch: 2\n", "Loss: 0.15238243341445923, Epoch: 2\n", "Loss: 0.1480635404586792, Epoch: 2\n", "Loss: 0.16778671741485596, Epoch: 2\n", "Loss: 0.1499747931957245, Epoch: 2\n", "Loss: 0.14654149115085602, Epoch: 2\n", "Loss: 0.15334898233413696, Epoch: 2\n", "Loss: 0.14312013983726501, Epoch: 2\n", "Loss: 0.14889495074748993, Epoch: 2\n", "Loss: 0.15227057039737701, Epoch: 2\n", "Loss: 0.15047228336334229, Epoch: 2\n", "Loss: 0.1697094738483429, Epoch: 2\n", "Loss: 0.14746415615081787, Epoch: 2\n", "Loss: 0.14284475147724152, Epoch: 2\n", "Loss: 0.14408795535564423, Epoch: 2\n", "Loss: 0.1655958741903305, Epoch: 2\n", "Loss: 0.15247742831707, Epoch: 2\n", "Loss: 0.15246184170246124, Epoch: 2\n", "Loss: 0.1515989601612091, Epoch: 2\n", "Loss: 0.14632681012153625, Epoch: 2\n", "Loss: 0.15054377913475037, Epoch: 2\n", "Loss: 0.15041185915470123, Epoch: 2\n", "Loss: 0.15458422899246216, Epoch: 2\n", "Loss: 0.14498606324195862, Epoch: 2\n", "Loss: 0.1463392674922943, Epoch: 2\n", "Loss: 0.14906661212444305, Epoch: 2\n", "Loss: 0.15188321471214294, Epoch: 2\n", "Loss: 0.14843648672103882, Epoch: 2\n", "Loss: 0.1495736539363861, Epoch: 2\n", "Loss: 0.1508703976869583, Epoch: 2\n", "Loss: 0.1415024846792221, Epoch: 2\n", "Loss: 0.14888466894626617, Epoch: 2\n", "Loss: 0.14863824844360352, Epoch: 2\n", "Loss: 0.1804545372724533, Epoch: 2\n", "Loss: 0.14639806747436523, Epoch: 2\n", "Loss: 0.14789406955242157, Epoch: 2\n", "Loss: 0.1517217457294464, Epoch: 2\n", "Loss: 0.15233184397220612, Epoch: 2\n", "Loss: 0.14604727923870087, Epoch: 2\n", "Loss: 0.14814278483390808, Epoch: 2\n", "Loss: 0.14410676062107086, Epoch: 2\n", "Loss: 0.14756864309310913, Epoch: 2\n", "Loss: 0.15017764270305634, Epoch: 2\n", "Loss: 0.15275132656097412, Epoch: 2\n", "Loss: 0.15587218105793, Epoch: 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.14909355342388153, Epoch: 2\n", "Loss: 0.14263498783111572, Epoch: 2\n", "Loss: 0.14612025022506714, Epoch: 
2\n", "Loss: 0.147268146276474, Epoch: 2\n", "Loss: 0.14908260107040405, Epoch: 2\n", "Loss: 0.13998505473136902, Epoch: 2\n", "Loss: 0.14733323454856873, Epoch: 2\n", "Loss: 0.14932139217853546, Epoch: 2\n", "Loss: 0.14685842394828796, Epoch: 2\n", "Loss: 0.14514578878879547, Epoch: 2\n", "Loss: 0.14932207763195038, Epoch: 3\n", "Loss: 0.14646084606647491, Epoch: 3\n", "Loss: 0.14726775884628296, Epoch: 3\n", "Loss: 0.1431538462638855, Epoch: 3\n", "Loss: 0.15928220748901367, Epoch: 3\n", "Loss: 0.14593127369880676, Epoch: 3\n", "Loss: 0.16137444972991943, Epoch: 3\n", "Loss: 0.14118261635303497, Epoch: 3\n", "Loss: 0.1484784036874771, Epoch: 3\n", "Loss: 0.13823699951171875, Epoch: 3\n", "Loss: 0.1517200618982315, Epoch: 3\n", "Loss: 0.14846888184547424, Epoch: 3\n", "Loss: 0.14220844209194183, Epoch: 3\n", "Loss: 0.15510502457618713, Epoch: 3\n", "Loss: 0.14475488662719727, Epoch: 3\n", "Loss: 0.14743365347385406, Epoch: 3\n", "Loss: 0.1498083770275116, Epoch: 3\n", "Loss: 0.15431958436965942, Epoch: 3\n", "Loss: 0.1425355225801468, Epoch: 3\n", "Loss: 0.14830829203128815, Epoch: 3\n", "Loss: 0.14888258278369904, Epoch: 3\n", "Loss: 0.14516624808311462, Epoch: 3\n", "Loss: 0.14695410430431366, Epoch: 3\n", "Loss: 0.1333359032869339, Epoch: 3\n", "Loss: 0.14476899802684784, Epoch: 3\n", "Loss: 0.13775229454040527, Epoch: 3\n", "Loss: 0.14642129838466644, Epoch: 3\n", "Loss: 0.14050345122814178, Epoch: 3\n", "Loss: 0.14062468707561493, Epoch: 3\n", "Loss: 0.14814360439777374, Epoch: 3\n", "Loss: 0.14132237434387207, Epoch: 3\n", "Loss: 0.16100995242595673, Epoch: 3\n", "Loss: 0.14359651505947113, Epoch: 3\n", "Loss: 0.14262162148952484, Epoch: 3\n", "Loss: 0.13580752909183502, Epoch: 3\n", "Loss: 0.14163215458393097, Epoch: 3\n", "Loss: 0.15252500772476196, Epoch: 3\n", "Loss: 0.1420316845178604, Epoch: 3\n", "Loss: 0.1503005027770996, Epoch: 3\n", "Loss: 0.13034197688102722, Epoch: 3\n", "Loss: 0.14875632524490356, Epoch: 3\n", "Loss: 0.1441657543182373, Epoch: 
3\n", "Loss: 0.14227809011936188, Epoch: 3\n", "Loss: 0.14413177967071533, Epoch: 3\n", "Loss: 0.14078521728515625, Epoch: 3\n", "Loss: 0.1412949413061142, Epoch: 3\n", "Loss: 0.14173543453216553, Epoch: 3\n", "Loss: 0.13746042549610138, Epoch: 3\n", "Loss: 0.1389954835176468, Epoch: 3\n", "Loss: 0.21365982294082642, Epoch: 3\n", "Loss: 0.13948048651218414, Epoch: 3\n", "Loss: 0.2131626307964325, Epoch: 3\n", "Loss: 0.31315305829048157, Epoch: 3\n", "Loss: 0.18220385909080505, Epoch: 3\n", "Loss: 0.44851815700531006, Epoch: 3\n", "Loss: 0.357102632522583, Epoch: 3\n", "Loss: 0.3697846233844757, Epoch: 3\n", "Loss: 0.2790517807006836, Epoch: 3\n", "Loss: 0.2529226243495941, Epoch: 3\n", "Loss: 0.23575733602046967, Epoch: 3\n", "Loss: 0.19847509264945984, Epoch: 3\n", "Loss: 0.1706792116165161, Epoch: 3\n", "Loss: 0.2386464625597, Epoch: 3\n", "Loss: 0.19907152652740479, Epoch: 3\n", "Loss: 0.21236853301525116, Epoch: 3\n", "Loss: 0.2005978673696518, Epoch: 3\n", "Loss: 0.2367972433567047, Epoch: 3\n", "Loss: 0.17748074233531952, Epoch: 3\n", "Loss: 0.1794709414243698, Epoch: 3\n", "Loss: 0.19476911425590515, Epoch: 3\n", "Loss: 0.23755262792110443, Epoch: 3\n", "Loss: 0.18582652509212494, Epoch: 3\n", "Loss: 0.1880793571472168, Epoch: 3\n", "Loss: 0.21453958749771118, Epoch: 3\n", "Loss: 0.17659011483192444, Epoch: 3\n", "Loss: 0.18438029289245605, Epoch: 3\n", "Loss: 0.16922736167907715, Epoch: 3\n", "Loss: 0.16512253880500793, Epoch: 3\n", "Loss: 0.16392676532268524, Epoch: 3\n", "Loss: 0.16824956238269806, Epoch: 3\n", "Loss: 0.17162224650382996, Epoch: 3\n", "Loss: 0.18260185420513153, Epoch: 3\n", "Loss: 0.1664547324180603, Epoch: 3\n", "Loss: 0.16220341622829437, Epoch: 3\n", "Loss: 0.16524864733219147, Epoch: 3\n", "Loss: 0.19387678802013397, Epoch: 3\n", "Loss: 0.15860770642757416, Epoch: 3\n", "Loss: 0.15186947584152222, Epoch: 3\n", "Loss: 0.1443750560283661, Epoch: 3\n", "Loss: 0.1585564911365509, Epoch: 3\n", "Loss: 0.17480601370334625, Epoch: 3\n", 
"Loss: 0.16876228153705597, Epoch: 3\n", "Loss: 0.15675309300422668, Epoch: 3\n", "Loss: 0.14876218140125275, Epoch: 3\n", "Loss: 0.19663038849830627, Epoch: 3\n", "Loss: 0.14992070198059082, Epoch: 3\n", "Loss: 0.166384756565094, Epoch: 3\n", "Loss: 0.15551204979419708, Epoch: 3\n", "Loss: 0.16162584722042084, Epoch: 3\n", "Loss: 0.16103392839431763, Epoch: 3\n", "Loss: 0.15038976073265076, Epoch: 3\n", "Loss: 0.1739073544740677, Epoch: 3\n", "Loss: 0.1523425579071045, Epoch: 3\n", "Loss: 0.1695767343044281, Epoch: 3\n", "Loss: 0.15175601840019226, Epoch: 3\n", "Loss: 0.14892224967479706, Epoch: 3\n", "Loss: 0.16292822360992432, Epoch: 3\n", "Loss: 0.14466966688632965, Epoch: 3\n", "Loss: 0.14638815820217133, Epoch: 3\n", "Loss: 0.14238741993904114, Epoch: 3\n", "Loss: 0.1467553973197937, Epoch: 3\n", "Loss: 0.1429072618484497, Epoch: 3\n", "Loss: 0.14860624074935913, Epoch: 3\n", "Loss: 0.14519362151622772, Epoch: 3\n", "Loss: 0.17172355949878693, Epoch: 3\n", "Loss: 0.14280566573143005, Epoch: 3\n", "Loss: 0.14112864434719086, Epoch: 3\n", "Loss: 0.14171293377876282, Epoch: 3\n", "Loss: 0.164906844496727, Epoch: 3\n", "Loss: 0.15102867782115936, Epoch: 3\n", "Loss: 0.14635801315307617, Epoch: 3\n", "Loss: 0.14682762324810028, Epoch: 3\n", "Loss: 0.14403878152370453, Epoch: 3\n", "Loss: 0.14384980499744415, Epoch: 3\n", "Loss: 0.15273280441761017, Epoch: 3\n", "Loss: 0.15245644748210907, Epoch: 3\n", "Loss: 0.14466692507266998, Epoch: 3\n", "Loss: 0.14336150884628296, Epoch: 3\n", "Loss: 0.14216336607933044, Epoch: 3\n", "Loss: 0.14766675233840942, Epoch: 3\n", "Loss: 0.1462327241897583, Epoch: 3\n", "Loss: 0.14618737995624542, Epoch: 3\n", "Loss: 0.14568249881267548, Epoch: 3\n", "Loss: 0.1387554109096527, Epoch: 3\n", "Loss: 0.13994747400283813, Epoch: 3\n", "Loss: 0.13912895321846008, Epoch: 3\n", "Loss: 0.1879383772611618, Epoch: 3\n", "Loss: 0.1433798223733902, Epoch: 3\n", "Loss: 0.14359787106513977, Epoch: 3\n", "Loss: 0.14942362904548645, Epoch: 3\n", 
"Loss: 0.14841613173484802, Epoch: 3\n", "Loss: 0.14370928704738617, Epoch: 3\n", "Loss: 0.1471216380596161, Epoch: 3\n", "Loss: 0.14228633046150208, Epoch: 3\n", "Loss: 0.14556394517421722, Epoch: 3\n", "Loss: 0.1456202119588852, Epoch: 3\n", "Loss: 0.1465320587158203, Epoch: 3\n", "Loss: 0.15671274065971375, Epoch: 3\n", "Loss: 0.14550618827342987, Epoch: 3\n", "Loss: 0.13757435977458954, Epoch: 3\n", "Loss: 0.14069128036499023, Epoch: 3\n", "Loss: 0.14401018619537354, Epoch: 3\n", "Loss: 0.14526163041591644, Epoch: 3\n", "Loss: 0.1361975520849228, Epoch: 3\n", "Loss: 0.1407855749130249, Epoch: 3\n", "Loss: 0.1457691341638565, Epoch: 3\n", "Loss: 0.145160973072052, Epoch: 3\n", "Loss: 0.1382274031639099, Epoch: 3\n", "Loss: 0.14645111560821533, Epoch: 4\n", "Loss: 0.14706067740917206, Epoch: 4\n", "Loss: 0.14455042779445648, Epoch: 4\n", "Loss: 0.13920459151268005, Epoch: 4\n", "Loss: 0.15357601642608643, Epoch: 4\n", "Loss: 0.14392921328544617, Epoch: 4\n", "Loss: 0.1616637110710144, Epoch: 4\n", "Loss: 0.13883942365646362, Epoch: 4\n", "Loss: 0.13806527853012085, Epoch: 4\n", "Loss: 0.13434778153896332, Epoch: 4\n", "Loss: 0.14757214486598969, Epoch: 4\n", "Loss: 0.14733949303627014, Epoch: 4\n", "Loss: 0.1386975646018982, Epoch: 4\n", "Loss: 0.15403154492378235, Epoch: 4\n", "Loss: 0.1409301906824112, Epoch: 4\n", "Loss: 0.143764927983284, Epoch: 4\n", "Loss: 0.142208531498909, Epoch: 4\n", "Loss: 0.14942139387130737, Epoch: 4\n", "Loss: 0.14091135561466217, Epoch: 4\n", "Loss: 0.14473268389701843, Epoch: 4\n", "Loss: 0.14563648402690887, Epoch: 4\n", "Loss: 0.13751356303691864, Epoch: 4\n", "Loss: 0.13772264122962952, Epoch: 4\n", "Loss: 0.13282646238803864, Epoch: 4\n", "Loss: 0.14611919224262238, Epoch: 4\n", "Loss: 0.13140395283699036, Epoch: 4\n", "Loss: 0.14141005277633667, Epoch: 4\n", "Loss: 0.13362324237823486, Epoch: 4\n", "Loss: 0.13560856878757477, Epoch: 4\n", "Loss: 0.1455288678407669, Epoch: 4\n", "Loss: 0.13728894293308258, Epoch: 4\n", "Loss: 
0.16653631627559662, Epoch: 4\n", "Loss: 0.1408710777759552, Epoch: 4\n", "Loss: 0.13909664750099182, Epoch: 4\n", "Loss: 0.13380306959152222, Epoch: 4\n", "Loss: 0.1373494416475296, Epoch: 4\n", "Loss: 0.14459766447544098, Epoch: 4\n", "Loss: 0.13684412837028503, Epoch: 4\n", "Loss: 0.13902482390403748, Epoch: 4\n", "Loss: 0.1350281685590744, Epoch: 4\n", "Loss: 0.14071603119373322, Epoch: 4\n", "Loss: 0.13675397634506226, Epoch: 4\n", "Loss: 0.14031022787094116, Epoch: 4\n", "Loss: 0.13689054548740387, Epoch: 4\n", "Loss: 0.13759593665599823, Epoch: 4\n", "Loss: 0.1371559202671051, Epoch: 4\n", "Loss: 0.13737107813358307, Epoch: 4\n", "Loss: 0.13100741803646088, Epoch: 4\n", "Loss: 0.1358410120010376, Epoch: 4\n", "Loss: 0.1382497102022171, Epoch: 4\n", "Loss: 0.13645248115062714, Epoch: 4\n", "Loss: 0.1361786127090454, Epoch: 4\n", "Loss: 0.13501687347888947, Epoch: 4\n", "Loss: 0.1417567878961563, Epoch: 4\n", "Loss: 0.14279848337173462, Epoch: 4\n", "Loss: 0.1270349621772766, Epoch: 4\n", "Loss: 0.1460144817829132, Epoch: 4\n", "Loss: 0.1387779265642166, Epoch: 4\n", "Loss: 0.1424030065536499, Epoch: 4\n", "Loss: 0.14144709706306458, Epoch: 4\n", "Loss: 0.13189923763275146, Epoch: 4\n", "Loss: 0.18845082819461823, Epoch: 4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.12886185944080353, Epoch: 4\n", "Loss: 0.137564018368721, Epoch: 4\n", "Loss: 0.16131561994552612, Epoch: 4\n", "Loss: 0.1416606456041336, Epoch: 4\n", "Loss: 0.14148710668087006, Epoch: 4\n", "Loss: 0.14193467795848846, Epoch: 4\n", "Loss: 0.1345425546169281, Epoch: 4\n", "Loss: 0.11765842884778976, Epoch: 4\n", "Loss: 0.19954945147037506, Epoch: 4\n", "Loss: 0.12320726364850998, Epoch: 4\n", "Loss: 0.1290886402130127, Epoch: 4\n", "Loss: 0.13515831530094147, Epoch: 4\n", "Loss: 0.1401388943195343, Epoch: 4\n", "Loss: 0.17238909006118774, Epoch: 4\n", "Loss: 0.13727407157421112, Epoch: 4\n", "Loss: 0.1388918161392212, Epoch: 4\n", "Loss: 0.12661297619342804, Epoch: 
4\n", "Loss: 0.1401134878396988, Epoch: 4\n", "Loss: 0.1431579738855362, Epoch: 4\n", "Loss: 0.19006170332431793, Epoch: 4\n", "Loss: 0.14119674265384674, Epoch: 4\n", "Loss: 0.1283990442752838, Epoch: 4\n", "Loss: 0.13079889118671417, Epoch: 4\n", "Loss: 0.1374884694814682, Epoch: 4\n", "Loss: 0.13765540719032288, Epoch: 4\n", "Loss: 0.11636094003915787, Epoch: 4\n", "Loss: 0.13191860914230347, Epoch: 4\n", "Loss: 0.13867181539535522, Epoch: 4\n", "Loss: 0.14593051373958588, Epoch: 4\n", "Loss: 0.132233127951622, Epoch: 4\n", "Loss: 0.13886763155460358, Epoch: 4\n", "Loss: 0.14014451205730438, Epoch: 4\n", "Loss: 0.17880161106586456, Epoch: 4\n", "Loss: 0.1320728063583374, Epoch: 4\n", "Loss: 0.1291753053665161, Epoch: 4\n", "Loss: 0.138149231672287, Epoch: 4\n", "Loss: 0.1335010826587677, Epoch: 4\n", "Loss: 0.1274981051683426, Epoch: 4\n", "Loss: 0.13696089386940002, Epoch: 4\n", "Loss: 0.13833221793174744, Epoch: 4\n", "Loss: 0.144839346408844, Epoch: 4\n", "Loss: 0.15894193947315216, Epoch: 4\n", "Loss: 0.15039022266864777, Epoch: 4\n", "Loss: 0.13267137110233307, Epoch: 4\n", "Loss: 0.1594497561454773, Epoch: 4\n", "Loss: 0.12029702961444855, Epoch: 4\n", "Loss: 0.12143255770206451, Epoch: 4\n", "Loss: 0.13340750336647034, Epoch: 4\n", "Loss: 0.1255176067352295, Epoch: 4\n", "Loss: 0.12960122525691986, Epoch: 4\n", "Loss: 0.13134732842445374, Epoch: 4\n", "Loss: 0.13162006437778473, Epoch: 4\n", "Loss: 0.17416930198669434, Epoch: 4\n", "Loss: 0.12281976640224457, Epoch: 4\n", "Loss: 0.1262483447790146, Epoch: 4\n", "Loss: 0.13074268400669098, Epoch: 4\n", "Loss: 0.14798814058303833, Epoch: 4\n", "Loss: 0.13607388734817505, Epoch: 4\n", "Loss: 0.1388680636882782, Epoch: 4\n", "Loss: 0.13348595798015594, Epoch: 4\n", "Loss: 0.1340729147195816, Epoch: 4\n", "Loss: 0.12516005337238312, Epoch: 4\n", "Loss: 0.138189896941185, Epoch: 4\n", "Loss: 0.13944798707962036, Epoch: 4\n", "Loss: 0.12049956619739532, Epoch: 4\n", "Loss: 0.1304030865430832, Epoch: 4\n", "Loss: 
0.1271418333053589, Epoch: 4\n", "Loss: 0.129023939371109, Epoch: 4\n", "Loss: 0.13616971671581268, Epoch: 4\n", "Loss: 0.1296188086271286, Epoch: 4\n", "Loss: 0.13445493578910828, Epoch: 4\n", "Loss: 0.12608228623867035, Epoch: 4\n", "Loss: 0.1271897405385971, Epoch: 4\n", "Loss: 0.12908056378364563, Epoch: 4\n", "Loss: 0.16020454466342926, Epoch: 4\n", "Loss: 0.127391055226326, Epoch: 4\n", "Loss: 0.1296154111623764, Epoch: 4\n", "Loss: 0.13318374752998352, Epoch: 4\n", "Loss: 0.1316712349653244, Epoch: 4\n", "Loss: 0.12991346418857574, Epoch: 4\n", "Loss: 0.1342085301876068, Epoch: 4\n", "Loss: 0.12700845301151276, Epoch: 4\n", "Loss: 0.1335075944662094, Epoch: 4\n", "Loss: 0.12601137161254883, Epoch: 4\n", "Loss: 0.134405255317688, Epoch: 4\n", "Loss: 0.11723560839891434, Epoch: 4\n", "Loss: 0.12238363921642303, Epoch: 4\n", "Loss: 0.12179684638977051, Epoch: 4\n", "Loss: 0.12812206149101257, Epoch: 4\n", "Loss: 0.13349272310733795, Epoch: 4\n", "Loss: 0.13319160044193268, Epoch: 4\n", "Loss: 0.11952850967645645, Epoch: 4\n", "Loss: 0.1268928050994873, Epoch: 4\n", "Loss: 0.13243825733661652, Epoch: 4\n", "Loss: 0.13336098194122314, Epoch: 4\n", "Loss: 0.14038413763046265, Epoch: 4\n", "Loss: 0.1249275803565979, Epoch: 5\n", "Loss: 0.12349969148635864, Epoch: 5\n", "Loss: 0.14533382654190063, Epoch: 5\n", "Loss: 0.1261834055185318, Epoch: 5\n", "Loss: 0.14619743824005127, Epoch: 5\n", "Loss: 0.13299092650413513, Epoch: 5\n", "Loss: 0.1651540845632553, Epoch: 5\n", "Loss: 0.13018113374710083, Epoch: 5\n", "Loss: 0.12087367475032806, Epoch: 5\n", "Loss: 0.1217261552810669, Epoch: 5\n", "Loss: 0.13018827140331268, Epoch: 5\n", "Loss: 0.13194221258163452, Epoch: 5\n", "Loss: 0.12210225313901901, Epoch: 5\n", "Loss: 0.13475942611694336, Epoch: 5\n", "Loss: 0.13920359313488007, Epoch: 5\n", "Loss: 0.13548718392848969, Epoch: 5\n", "Loss: 0.12695132195949554, Epoch: 5\n", "Loss: 0.13285671174526215, Epoch: 5\n", "Loss: 0.12065783143043518, Epoch: 5\n", "Loss: 
0.1276872307062149, Epoch: 5\n", "Loss: 0.13304223120212555, Epoch: 5\n", "Loss: 0.12204690277576447, Epoch: 5\n", "Loss: 0.12558265030384064, Epoch: 5\n", "Loss: 0.10881917923688889, Epoch: 5\n", "Loss: 0.12909428775310516, Epoch: 5\n", "Loss: 0.11289199441671371, Epoch: 5\n", "Loss: 0.12797780334949493, Epoch: 5\n", "Loss: 0.1226205825805664, Epoch: 5\n", "Loss: 0.11554985493421555, Epoch: 5\n", "Loss: 0.1257942020893097, Epoch: 5\n", "Loss: 0.13876764476299286, Epoch: 5\n", "Loss: 0.15175792574882507, Epoch: 5\n", "Loss: 0.12390465289354324, Epoch: 5\n", "Loss: 0.1710795760154724, Epoch: 5\n", "Loss: 0.1199038103222847, Epoch: 5\n", "Loss: 0.1343294382095337, Epoch: 5\n", "Loss: 0.1328621208667755, Epoch: 5\n", "Loss: 0.11477426439523697, Epoch: 5\n", "Loss: 0.12490951269865036, Epoch: 5\n", "Loss: 0.11478014290332794, Epoch: 5\n", "Loss: 0.13298383355140686, Epoch: 5\n", "Loss: 0.13167645037174225, Epoch: 5\n", "Loss: 0.12593436241149902, Epoch: 5\n", "Loss: 0.13232290744781494, Epoch: 5\n", "Loss: 0.13195490837097168, Epoch: 5\n", "Loss: 0.12542849779129028, Epoch: 5\n", "Loss: 0.13812699913978577, Epoch: 5\n", "Loss: 0.12022323906421661, Epoch: 5\n", "Loss: 0.12176539748907089, Epoch: 5\n", "Loss: 0.12790079414844513, Epoch: 5\n", "Loss: 0.12351775169372559, Epoch: 5\n", "Loss: 0.11888983845710754, Epoch: 5\n", "Loss: 0.12403333187103271, Epoch: 5\n", "Loss: 0.13656170666217804, Epoch: 5\n", "Loss: 0.13339126110076904, Epoch: 5\n", "Loss: 0.11131792515516281, Epoch: 5\n", "Loss: 0.12859481573104858, Epoch: 5\n", "Loss: 0.13500623404979706, Epoch: 5\n", "Loss: 0.12879762053489685, Epoch: 5\n", "Loss: 0.1277874857187271, Epoch: 5\n", "Loss: 0.11442019045352936, Epoch: 5\n", "Loss: 0.1521713137626648, Epoch: 5\n", "Loss: 0.1213955506682396, Epoch: 5\n", "Loss: 0.13298127055168152, Epoch: 5\n", "Loss: 0.14172470569610596, Epoch: 5\n", "Loss: 0.13587993383407593, Epoch: 5\n", "Loss: 0.14074204862117767, Epoch: 5\n", "Loss: 0.13420546054840088, Epoch: 5\n", "Loss: 
0.12752939760684967, Epoch: 5\n", "Loss: 0.10668251663446426, Epoch: 5\n", "Loss: 0.19405558705329895, Epoch: 5\n", "Loss: 0.12440341711044312, Epoch: 5\n", "Loss: 0.12243395298719406, Epoch: 5\n", "Loss: 0.12695355713367462, Epoch: 5\n", "Loss: 0.13261005282402039, Epoch: 5\n", "Loss: 0.1808023750782013, Epoch: 5\n", "Loss: 0.12548168003559113, Epoch: 5\n", "Loss: 0.1334279477596283, Epoch: 5\n", "Loss: 0.13674396276474, Epoch: 5\n", "Loss: 0.13056345283985138, Epoch: 5\n", "Loss: 0.13382840156555176, Epoch: 5\n", "Loss: 0.15729926526546478, Epoch: 5\n", "Loss: 0.12456748634576797, Epoch: 5\n", "Loss: 0.1150851845741272, Epoch: 5\n", "Loss: 0.12835916876792908, Epoch: 5\n", "Loss: 0.17279915511608124, Epoch: 5\n", "Loss: 0.13193294405937195, Epoch: 5\n", "Loss: 0.10962967574596405, Epoch: 5\n", "Loss: 0.12050677835941315, Epoch: 5\n", "Loss: 0.17489789426326752, Epoch: 5\n", "Loss: 0.13095614314079285, Epoch: 5\n", "Loss: 0.12931302189826965, Epoch: 5\n", "Loss: 0.12235209345817566, Epoch: 5\n", "Loss: 0.1342233270406723, Epoch: 5\n", "Loss: 0.17001110315322876, Epoch: 5\n", "Loss: 0.123800128698349, Epoch: 5\n", "Loss: 0.20985817909240723, Epoch: 5\n", "Loss: 0.13027140498161316, Epoch: 5\n", "Loss: 0.11940208822488785, Epoch: 5\n", "Loss: 0.13204459846019745, Epoch: 5\n", "Loss: 0.13626925647258759, Epoch: 5\n", "Loss: 0.12456974387168884, Epoch: 5\n", "Loss: 0.1372535526752472, Epoch: 5\n", "Loss: 0.12558864057064056, Epoch: 5\n", "Loss: 0.14300382137298584, Epoch: 5\n", "Loss: 0.1280129849910736, Epoch: 5\n", "Loss: 0.12656547129154205, Epoch: 5\n", "Loss: 0.10765340924263, Epoch: 5\n", "Loss: 0.10957075655460358, Epoch: 5\n", "Loss: 0.14550140500068665, Epoch: 5\n", "Loss: 0.12373486161231995, Epoch: 5\n", "Loss: 0.11921708285808563, Epoch: 5\n", "Loss: 0.11897507309913635, Epoch: 5\n", "Loss: 0.11595134437084198, Epoch: 5\n", "Loss: 0.15143433213233948, Epoch: 5\n", "Loss: 0.10788632184267044, Epoch: 5\n", "Loss: 0.1180427297949791, Epoch: 5\n", "Loss: 
0.12454649806022644, Epoch: 5\n", "Loss: 0.1388833224773407, Epoch: 5\n", "Loss: 0.12321362644433975, Epoch: 5\n", "Loss: 0.1299155354499817, Epoch: 5\n", "Loss: 0.12170170992612839, Epoch: 5\n", "Loss: 0.12160171568393707, Epoch: 5\n", "Loss: 0.11576078832149506, Epoch: 5\n", "Loss: 0.1474548727273941, Epoch: 5\n", "Loss: 0.12996402382850647, Epoch: 5\n", "Loss: 0.11152004450559616, Epoch: 5\n", "Loss: 0.11374114453792572, Epoch: 5\n", "Loss: 0.1124294102191925, Epoch: 5\n", "Loss: 0.11947697401046753, Epoch: 5\n", "Loss: 0.12320362776517868, Epoch: 5\n", "Loss: 0.12594015896320343, Epoch: 5\n", "Loss: 0.1282290816307068, Epoch: 5\n", "Loss: 0.11917021125555038, Epoch: 5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.11522866785526276, Epoch: 5\n", "Loss: 0.11686207354068756, Epoch: 5\n", "Loss: 0.13003912568092346, Epoch: 5\n", "Loss: 0.10768930613994598, Epoch: 5\n", "Loss: 0.11665098369121552, Epoch: 5\n", "Loss: 0.12027396261692047, Epoch: 5\n", "Loss: 0.1223018616437912, Epoch: 5\n", "Loss: 0.11921297758817673, Epoch: 5\n", "Loss: 0.126190647482872, Epoch: 5\n", "Loss: 0.11158805340528488, Epoch: 5\n", "Loss: 0.12258896976709366, Epoch: 5\n", "Loss: 0.10933675616979599, Epoch: 5\n", "Loss: 0.12323165684938431, Epoch: 5\n", "Loss: 0.09231279790401459, Epoch: 5\n", "Loss: 0.10996890813112259, Epoch: 5\n", "Loss: 0.10218094289302826, Epoch: 5\n", "Loss: 0.11884516477584839, Epoch: 5\n", "Loss: 0.11803596466779709, Epoch: 5\n", "Loss: 0.11715345084667206, Epoch: 5\n", "Loss: 0.10423947125673294, Epoch: 5\n", "Loss: 0.11898159980773926, Epoch: 5\n", "Loss: 0.1188461184501648, Epoch: 5\n", "Loss: 0.12247420847415924, Epoch: 5\n", "Loss: 0.10961997509002686, Epoch: 5\n", "Loss: 0.11624372005462646, Epoch: 6\n", "Loss: 0.10715224593877792, Epoch: 6\n", "Loss: 0.13775856792926788, Epoch: 6\n", "Loss: 0.11116379499435425, Epoch: 6\n", "Loss: 0.13297320902347565, Epoch: 6\n", "Loss: 0.12099210172891617, Epoch: 6\n", "Loss: 0.11810016632080078, 
Epoch: 6\n", "Loss: 0.12210290879011154, Epoch: 6\n", "Loss: 0.1103905439376831, Epoch: 6\n", "Loss: 0.10696553438901901, Epoch: 6\n", "Loss: 0.11589460074901581, Epoch: 6\n", "Loss: 0.11533137410879135, Epoch: 6\n", "Loss: 0.10772214829921722, Epoch: 6\n", "Loss: 0.12689828872680664, Epoch: 6\n", "Loss: 0.12170947343111038, Epoch: 6\n", "Loss: 0.1141921728849411, Epoch: 6\n", "Loss: 0.11483687907457352, Epoch: 6\n", "Loss: 0.12149053812026978, Epoch: 6\n", "Loss: 0.10325966775417328, Epoch: 6\n", "Loss: 0.11688613146543503, Epoch: 6\n", "Loss: 0.11686090379953384, Epoch: 6\n", "Loss: 0.10853508859872818, Epoch: 6\n", "Loss: 0.12138164043426514, Epoch: 6\n", "Loss: 0.10637712478637695, Epoch: 6\n", "Loss: 0.11518768966197968, Epoch: 6\n", "Loss: 0.1005590632557869, Epoch: 6\n", "Loss: 0.10979194939136505, Epoch: 6\n", "Loss: 0.10970509797334671, Epoch: 6\n", "Loss: 0.10047701746225357, Epoch: 6\n", "Loss: 0.11055599898099899, Epoch: 6\n", "Loss: 0.1114952564239502, Epoch: 6\n", "Loss: 0.1298879086971283, Epoch: 6\n", "Loss: 0.1118747889995575, Epoch: 6\n", "Loss: 0.11944448947906494, Epoch: 6\n", "Loss: 0.11376014351844788, Epoch: 6\n", "Loss: 0.13856884837150574, Epoch: 6\n", "Loss: 0.12177485227584839, Epoch: 6\n", "Loss: 0.09694704413414001, Epoch: 6\n", "Loss: 0.11368582397699356, Epoch: 6\n", "Loss: 0.11257602274417877, Epoch: 6\n", "Loss: 0.1129927709698677, Epoch: 6\n", "Loss: 0.11432743817567825, Epoch: 6\n", "Loss: 0.11266694217920303, Epoch: 6\n", "Loss: 0.10029543936252594, Epoch: 6\n", "Loss: 0.10305653512477875, Epoch: 6\n", "Loss: 0.1084495559334755, Epoch: 6\n", "Loss: 0.1262296736240387, Epoch: 6\n", "Loss: 0.13879963755607605, Epoch: 6\n", "Loss: 0.11159215122461319, Epoch: 6\n", "Loss: 0.10721106827259064, Epoch: 6\n", "Loss: 0.10601997375488281, Epoch: 6\n", "Loss: 0.10585113614797592, Epoch: 6\n", "Loss: 0.12005317956209183, Epoch: 6\n", "Loss: 0.11511770635843277, Epoch: 6\n", "Loss: 0.12122543901205063, Epoch: 6\n", "Loss: 0.0961105078458786, 
Epoch: 6\n", "Loss: 0.11446699500083923, Epoch: 6\n", "Loss: 0.11680401861667633, Epoch: 6\n", "Loss: 0.11895857751369476, Epoch: 6\n", "Loss: 0.11428070813417435, Epoch: 6\n", "Loss: 0.10440615564584732, Epoch: 6\n", "Loss: 0.10592535138130188, Epoch: 6\n", "Loss: 0.1098354384303093, Epoch: 6\n", "Loss: 0.11369583010673523, Epoch: 6\n", "Loss: 0.11765711009502411, Epoch: 6\n", "Loss: 0.12683521211147308, Epoch: 6\n", "Loss: 0.14331921935081482, Epoch: 6\n", "Loss: 0.12109503895044327, Epoch: 6\n", "Loss: 0.19002178311347961, Epoch: 6\n", "Loss: 0.132974311709404, Epoch: 6\n", "Loss: 0.18188707530498505, Epoch: 6\n", "Loss: 0.1156173050403595, Epoch: 6\n", "Loss: 0.21222077310085297, Epoch: 6\n", "Loss: 0.11632455885410309, Epoch: 6\n", "Loss: 0.11654199659824371, Epoch: 6\n", "Loss: 0.2229778915643692, Epoch: 6\n", "Loss: 0.13579709827899933, Epoch: 6\n", "Loss: 0.13772152364253998, Epoch: 6\n", "Loss: 0.21927736699581146, Epoch: 6\n", "Loss: 0.14530208706855774, Epoch: 6\n", "Loss: 0.12683066725730896, Epoch: 6\n", "Loss: 0.12871938943862915, Epoch: 6\n", "Loss: 0.12640881538391113, Epoch: 6\n", "Loss: 0.12194661796092987, Epoch: 6\n", "Loss: 0.11213715374469757, Epoch: 6\n", "Loss: 0.13071177899837494, Epoch: 6\n", "Loss: 0.14279155433177948, Epoch: 6\n", "Loss: 0.09821438789367676, Epoch: 6\n", "Loss: 0.11000728607177734, Epoch: 6\n", "Loss: 0.13279211521148682, Epoch: 6\n", "Loss: 0.13721898198127747, Epoch: 6\n", "Loss: 0.11367487162351608, Epoch: 6\n", "Loss: 0.13350757956504822, Epoch: 6\n", "Loss: 0.1248038113117218, Epoch: 6\n", "Loss: 0.15275642275810242, Epoch: 6\n", "Loss: 0.1401827335357666, Epoch: 6\n", "Loss: 0.1166832372546196, Epoch: 6\n", "Loss: 0.1256875991821289, Epoch: 6\n", "Loss: 0.11419999599456787, Epoch: 6\n", "Loss: 0.11006093770265579, Epoch: 6\n", "Loss: 0.11871582269668579, Epoch: 6\n", "Loss: 0.11190492659807205, Epoch: 6\n", "Loss: 0.12537062168121338, Epoch: 6\n", "Loss: 0.08388262987136841, Epoch: 6\n", "Loss: 0.1354859173297882, 
Epoch: 6\n", "Loss: 0.1201799213886261, Epoch: 6\n", "Loss: 0.08500127494335175, Epoch: 6\n", "Loss: 0.09311848878860474, Epoch: 6\n", "Loss: 0.09804543852806091, Epoch: 6\n", "Loss: 0.13395635783672333, Epoch: 6\n", "Loss: 0.11851917952299118, Epoch: 6\n", "Loss: 0.10380667448043823, Epoch: 6\n", "Loss: 0.10534995794296265, Epoch: 6\n", "Loss: 0.10853283107280731, Epoch: 6\n", "Loss: 0.14459960162639618, Epoch: 6\n", "Loss: 0.09134180843830109, Epoch: 6\n", "Loss: 0.10183943808078766, Epoch: 6\n", "Loss: 0.11426414549350739, Epoch: 6\n", "Loss: 0.11680352687835693, Epoch: 6\n", "Loss: 0.1100248321890831, Epoch: 6\n", "Loss: 0.11217734217643738, Epoch: 6\n", "Loss: 0.11318906396627426, Epoch: 6\n", "Loss: 0.11201858520507812, Epoch: 6\n", "Loss: 0.1053985059261322, Epoch: 6\n", "Loss: 0.12794645130634308, Epoch: 6\n", "Loss: 0.11207936704158783, Epoch: 6\n", "Loss: 0.10350680351257324, Epoch: 6\n", "Loss: 0.10383594036102295, Epoch: 6\n", "Loss: 0.10246621817350388, Epoch: 6\n", "Loss: 0.10486087203025818, Epoch: 6\n", "Loss: 0.11843543499708176, Epoch: 6\n", "Loss: 0.10555055737495422, Epoch: 6\n", "Loss: 0.11247757077217102, Epoch: 6\n", "Loss: 0.11798767000436783, Epoch: 6\n", "Loss: 0.10157867521047592, Epoch: 6\n", "Loss: 0.10155266523361206, Epoch: 6\n", "Loss: 0.10669185221195221, Epoch: 6\n", "Loss: 0.09663646668195724, Epoch: 6\n", "Loss: 0.11037393659353256, Epoch: 6\n", "Loss: 0.10974664241075516, Epoch: 6\n", "Loss: 0.10823851823806763, Epoch: 6\n", "Loss: 0.1105186939239502, Epoch: 6\n", "Loss: 0.10878249257802963, Epoch: 6\n", "Loss: 0.0992928072810173, Epoch: 6\n", "Loss: 0.11492595076560974, Epoch: 6\n", "Loss: 0.10114570707082748, Epoch: 6\n", "Loss: 0.11305642127990723, Epoch: 6\n", "Loss: 0.08154793083667755, Epoch: 6\n", "Loss: 0.09805180877447128, Epoch: 6\n", "Loss: 0.08986265957355499, Epoch: 6\n", "Loss: 0.1070898175239563, Epoch: 6\n", "Loss: 0.10565295815467834, Epoch: 6\n", "Loss: 0.10418456047773361, Epoch: 6\n", "Loss: 
0.09737840294837952, Epoch: 6\n", "Loss: 0.1021442785859108, Epoch: 6\n", "Loss: 0.10725658386945724, Epoch: 6\n", "Loss: 0.10730786621570587, Epoch: 6\n", "Loss: 0.09617894887924194, Epoch: 6\n", "Loss: 0.11127995699644089, Epoch: 7\n", "Loss: 0.0874585509300232, Epoch: 7\n", "Loss: 0.12249134480953217, Epoch: 7\n", "Loss: 0.1020079031586647, Epoch: 7\n", "Loss: 0.11316166073083878, Epoch: 7\n", "Loss: 0.10774492472410202, Epoch: 7\n", "Loss: 0.09285763651132584, Epoch: 7\n", "Loss: 0.1098235622048378, Epoch: 7\n", "Loss: 0.09434773027896881, Epoch: 7\n", "Loss: 0.09606831520795822, Epoch: 7\n", "Loss: 0.10440029948949814, Epoch: 7\n", "Loss: 0.10212397575378418, Epoch: 7\n", "Loss: 0.09842287749052048, Epoch: 7\n", "Loss: 0.11679971218109131, Epoch: 7\n", "Loss: 0.10654377192258835, Epoch: 7\n", "Loss: 0.10638795793056488, Epoch: 7\n", "Loss: 0.1004040390253067, Epoch: 7\n", "Loss: 0.1108088567852974, Epoch: 7\n", "Loss: 0.09905845671892166, Epoch: 7\n", "Loss: 0.1053299531340599, Epoch: 7\n", "Loss: 0.1049308255314827, Epoch: 7\n", "Loss: 0.09539797902107239, Epoch: 7\n", "Loss: 0.096033975481987, Epoch: 7\n", "Loss: 0.09786777198314667, Epoch: 7\n", "Loss: 0.10437578707933426, Epoch: 7\n", "Loss: 0.08823202550411224, Epoch: 7\n", "Loss: 0.09576096385717392, Epoch: 7\n", "Loss: 0.09627911448478699, Epoch: 7\n", "Loss: 0.0886673852801323, Epoch: 7\n", "Loss: 0.10158203542232513, Epoch: 7\n", "Loss: 0.09936793148517609, Epoch: 7\n", "Loss: 0.10796260088682175, Epoch: 7\n", "Loss: 0.10200068354606628, Epoch: 7\n", "Loss: 0.1036141887307167, Epoch: 7\n", "Loss: 0.10264419764280319, Epoch: 7\n", "Loss: 0.1228143498301506, Epoch: 7\n", "Loss: 0.1109018325805664, Epoch: 7\n", "Loss: 0.08281095325946808, Epoch: 7\n", "Loss: 0.09651509672403336, Epoch: 7\n", "Loss: 0.10483517497777939, Epoch: 7\n", "Loss: 0.10054583847522736, Epoch: 7\n", "Loss: 0.09840130060911179, Epoch: 7\n", "Loss: 0.10182977467775345, Epoch: 7\n", "Loss: 0.09018167108297348, Epoch: 7\n", "Loss: 
0.08957422524690628, Epoch: 7\n", "Loss: 0.09855228662490845, Epoch: 7\n", "Loss: 0.11381145566701889, Epoch: 7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.09318786859512329, Epoch: 7\n", "Loss: 0.09626708924770355, Epoch: 7\n", "Loss: 0.09697861224412918, Epoch: 7\n", "Loss: 0.09386548399925232, Epoch: 7\n", "Loss: 0.09933487325906754, Epoch: 7\n", "Loss: 0.08954917639493942, Epoch: 7\n", "Loss: 0.09967973083257675, Epoch: 7\n", "Loss: 0.10813768208026886, Epoch: 7\n", "Loss: 0.08184698224067688, Epoch: 7\n", "Loss: 0.10852080583572388, Epoch: 7\n", "Loss: 0.10763221979141235, Epoch: 7\n", "Loss: 0.11019491404294968, Epoch: 7\n", "Loss: 0.1040777862071991, Epoch: 7\n", "Loss: 0.0847632959485054, Epoch: 7\n", "Loss: 0.08171959221363068, Epoch: 7\n", "Loss: 0.08776959776878357, Epoch: 7\n", "Loss: 0.10209298133850098, Epoch: 7\n", "Loss: 0.08908098191022873, Epoch: 7\n", "Loss: 0.10392419248819351, Epoch: 7\n", "Loss: 0.09996441006660461, Epoch: 7\n", "Loss: 0.11016766726970673, Epoch: 7\n", "Loss: 0.10478966683149338, Epoch: 7\n", "Loss: 0.07866249978542328, Epoch: 7\n", "Loss: 0.13597208261489868, Epoch: 7\n", "Loss: 0.1079997643828392, Epoch: 7\n", "Loss: 0.08729257434606552, Epoch: 7\n", "Loss: 0.10136373341083527, Epoch: 7\n", "Loss: 0.09196136891841888, Epoch: 7\n", "Loss: 0.1400473266839981, Epoch: 7\n", "Loss: 0.09910447895526886, Epoch: 7\n", "Loss: 0.10202774405479431, Epoch: 7\n", "Loss: 0.0815005972981453, Epoch: 7\n", "Loss: 0.10101136565208435, Epoch: 7\n", "Loss: 0.0984138771891594, Epoch: 7\n", "Loss: 0.09421591460704803, Epoch: 7\n", "Loss: 0.10190987586975098, Epoch: 7\n", "Loss: 0.08138076961040497, Epoch: 7\n", "Loss: 0.08836308121681213, Epoch: 7\n", "Loss: 0.10892144590616226, Epoch: 7\n", "Loss: 0.10090276598930359, Epoch: 7\n", "Loss: 0.07948984205722809, Epoch: 7\n", "Loss: 0.09559652209281921, Epoch: 7\n", "Loss: 0.10386673361063004, Epoch: 7\n", "Loss: 0.10217111557722092, Epoch: 7\n", "Loss: 
0.09979840368032455, Epoch: 7\n", "Loss: 0.09837999939918518, Epoch: 7\n", "Loss: 0.10334828495979309, Epoch: 7\n", "Loss: 0.10868719965219498, Epoch: 7\n", "Loss: 0.0996246188879013, Epoch: 7\n", "Loss: 0.09603465348482132, Epoch: 7\n", "Loss: 0.0976654514670372, Epoch: 7\n", "Loss: 0.09486246854066849, Epoch: 7\n", "Loss: 0.08585795760154724, Epoch: 7\n", "Loss: 0.10285493731498718, Epoch: 7\n", "Loss: 0.09559769928455353, Epoch: 7\n", "Loss: 0.10589518398046494, Epoch: 7\n", "Loss: 0.0665736198425293, Epoch: 7\n", "Loss: 0.12515269219875336, Epoch: 7\n", "Loss: 0.10746419429779053, Epoch: 7\n", "Loss: 0.06850691139698029, Epoch: 7\n", "Loss: 0.07568707317113876, Epoch: 7\n", "Loss: 0.08142060041427612, Epoch: 7\n", "Loss: 0.14092078804969788, Epoch: 7\n", "Loss: 0.10802359133958817, Epoch: 7\n", "Loss: 0.09362595528364182, Epoch: 7\n", "Loss: 0.09656472504138947, Epoch: 7\n", "Loss: 0.09624163806438446, Epoch: 7\n", "Loss: 0.1221676841378212, Epoch: 7\n", "Loss: 0.08465592563152313, Epoch: 7\n", "Loss: 0.09251771122217178, Epoch: 7\n", "Loss: 0.1012636050581932, Epoch: 7\n", "Loss: 0.08949314802885056, Epoch: 7\n", "Loss: 0.09692538529634476, Epoch: 7\n", "Loss: 0.10708724707365036, Epoch: 7\n", "Loss: 0.10580960661172867, Epoch: 7\n", "Loss: 0.10133131593465805, Epoch: 7\n", "Loss: 0.09440037608146667, Epoch: 7\n", "Loss: 0.11331550776958466, Epoch: 7\n", "Loss: 0.09957834333181381, Epoch: 7\n", "Loss: 0.08154382556676865, Epoch: 7\n", "Loss: 0.09030596911907196, Epoch: 7\n", "Loss: 0.09509352594614029, Epoch: 7\n", "Loss: 0.09161731600761414, Epoch: 7\n", "Loss: 0.10449625551700592, Epoch: 7\n", "Loss: 0.09905510395765305, Epoch: 7\n", "Loss: 0.10003232210874557, Epoch: 7\n", "Loss: 0.10228797793388367, Epoch: 7\n", "Loss: 0.08921154588460922, Epoch: 7\n", "Loss: 0.09422476589679718, Epoch: 7\n", "Loss: 0.09613534063100815, Epoch: 7\n", "Loss: 0.09139202535152435, Epoch: 7\n", "Loss: 0.09728746861219406, Epoch: 7\n", "Loss: 0.0973246768116951, Epoch: 7\n", 
"Loss: 0.09548674523830414, Epoch: 7\n", "Loss: 0.10061587393283844, Epoch: 7\n", "Loss: 0.09739266335964203, Epoch: 7\n", "Loss: 0.08884819597005844, Epoch: 7\n", "Loss: 0.09968846291303635, Epoch: 7\n", "Loss: 0.08458088338375092, Epoch: 7\n", "Loss: 0.10188835114240646, Epoch: 7\n", "Loss: 0.0672396719455719, Epoch: 7\n", "Loss: 0.08573003113269806, Epoch: 7\n", "Loss: 0.07837706059217453, Epoch: 7\n", "Loss: 0.10032759606838226, Epoch: 7\n", "Loss: 0.09497064352035522, Epoch: 7\n", "Loss: 0.0931612178683281, Epoch: 7\n", "Loss: 0.0899437889456749, Epoch: 7\n", "Loss: 0.09274802356958389, Epoch: 7\n", "Loss: 0.09666411578655243, Epoch: 7\n", "Loss: 0.09964553266763687, Epoch: 7\n", "Loss: 0.08965051174163818, Epoch: 7\n", "Loss: 0.09531083703041077, Epoch: 8\n", "Loss: 0.0764368548989296, Epoch: 8\n", "Loss: 0.09976893663406372, Epoch: 8\n", "Loss: 0.09121975302696228, Epoch: 8\n", "Loss: 0.09898321330547333, Epoch: 8\n", "Loss: 0.09821636974811554, Epoch: 8\n", "Loss: 0.08225860446691513, Epoch: 8\n", "Loss: 0.10347147285938263, Epoch: 8\n", "Loss: 0.08411657810211182, Epoch: 8\n", "Loss: 0.08513495326042175, Epoch: 8\n", "Loss: 0.09337693452835083, Epoch: 8\n", "Loss: 0.09445364028215408, Epoch: 8\n", "Loss: 0.08690741658210754, Epoch: 8\n", "Loss: 0.10567010939121246, Epoch: 8\n", "Loss: 0.09979277849197388, Epoch: 8\n", "Loss: 0.10393866151571274, Epoch: 8\n", "Loss: 0.08208710700273514, Epoch: 8\n", "Loss: 0.10648505389690399, Epoch: 8\n", "Loss: 0.09276064485311508, Epoch: 8\n", "Loss: 0.10089381039142609, Epoch: 8\n", "Loss: 0.0965060368180275, Epoch: 8\n", "Loss: 0.0881606936454773, Epoch: 8\n", "Loss: 0.08871510624885559, Epoch: 8\n", "Loss: 0.09277530014514923, Epoch: 8\n", "Loss: 0.09860555082559586, Epoch: 8\n", "Loss: 0.07966770976781845, Epoch: 8\n", "Loss: 0.09030231833457947, Epoch: 8\n", "Loss: 0.0851370170712471, Epoch: 8\n", "Loss: 0.08056827634572983, Epoch: 8\n", "Loss: 0.09375553578138351, Epoch: 8\n", "Loss: 0.0939326360821724, Epoch: 
8\n", "Loss: 0.09433168917894363, Epoch: 8\n", "Loss: 0.09604447335004807, Epoch: 8\n", "Loss: 0.09176722913980484, Epoch: 8\n", "Loss: 0.09819795936346054, Epoch: 8\n", "Loss: 0.10539628565311432, Epoch: 8\n", "Loss: 0.10229707509279251, Epoch: 8\n", "Loss: 0.07657080888748169, Epoch: 8\n", "Loss: 0.09068753570318222, Epoch: 8\n", "Loss: 0.0944778323173523, Epoch: 8\n", "Loss: 0.09684833884239197, Epoch: 8\n", "Loss: 0.09293913841247559, Epoch: 8\n", "Loss: 0.10148358345031738, Epoch: 8\n", "Loss: 0.08338888734579086, Epoch: 8\n", "Loss: 0.08214575797319412, Epoch: 8\n", "Loss: 0.08863586187362671, Epoch: 8\n", "Loss: 0.10475130379199982, Epoch: 8\n", "Loss: 0.10403673350811005, Epoch: 8\n", "Loss: 0.08990421146154404, Epoch: 8\n", "Loss: 0.08858269453048706, Epoch: 8\n", "Loss: 0.08760341256856918, Epoch: 8\n", "Loss: 0.09343436360359192, Epoch: 8\n", "Loss: 0.0796404555439949, Epoch: 8\n", "Loss: 0.09368477016687393, Epoch: 8\n", "Loss: 0.0998200923204422, Epoch: 8\n", "Loss: 0.07597582042217255, Epoch: 8\n", "Loss: 0.10604184120893478, Epoch: 8\n", "Loss: 0.09754043072462082, Epoch: 8\n", "Loss: 0.10065465420484543, Epoch: 8\n", "Loss: 0.10647395253181458, Epoch: 8\n", "Loss: 0.08068972080945969, Epoch: 8\n", "Loss: 0.16115640103816986, Epoch: 8\n", "Loss: 0.11444346606731415, Epoch: 8\n", "Loss: 0.10705339163541794, Epoch: 8\n", "Loss: 0.09086104482412338, Epoch: 8\n", "Loss: 0.10477293282747269, Epoch: 8\n", "Loss: 0.09490729123353958, Epoch: 8\n", "Loss: 0.11211740970611572, Epoch: 8\n", "Loss: 0.10389804095029831, Epoch: 8\n", "Loss: 0.08844345808029175, Epoch: 8\n", "Loss: 0.21717138588428497, Epoch: 8\n", "Loss: 0.10999349504709244, Epoch: 8\n", "Loss: 0.08422992378473282, Epoch: 8\n", "Loss: 0.101418137550354, Epoch: 8\n", "Loss: 0.09684649854898453, Epoch: 8\n", "Loss: 0.1724022775888443, Epoch: 8\n", "Loss: 0.1314866691827774, Epoch: 8\n", "Loss: 0.09903855621814728, Epoch: 8\n", "Loss: 0.11843590438365936, Epoch: 8\n", "Loss: 0.12226806581020355, 
Epoch: 8\n", "Loss: 0.13927863538265228, Epoch: 8\n", "Loss: 0.10693665593862534, Epoch: 8\n", "Loss: 0.10284698754549026, Epoch: 8\n", "Loss: 0.09007545560598373, Epoch: 8\n", "Loss: 0.13318724930286407, Epoch: 8\n", "Loss: 0.11283812671899796, Epoch: 8\n", "Loss: 0.13335789740085602, Epoch: 8\n", "Loss: 0.08382488787174225, Epoch: 8\n", "Loss: 0.09779892861843109, Epoch: 8\n", "Loss: 0.11922663450241089, Epoch: 8\n", "Loss: 0.10391080379486084, Epoch: 8\n", "Loss: 0.09241219609975815, Epoch: 8\n", "Loss: 0.0914870947599411, Epoch: 8\n", "Loss: 0.10276245325803757, Epoch: 8\n", "Loss: 0.11850879341363907, Epoch: 8\n", "Loss: 0.11146795749664307, Epoch: 8\n", "Loss: 0.13590680062770844, Epoch: 8\n", "Loss: 0.10979834944009781, Epoch: 8\n", "Loss: 0.09215784817934036, Epoch: 8\n", "Loss: 0.10343850404024124, Epoch: 8\n", "Loss: 0.10027733445167542, Epoch: 8\n", "Loss: 0.09378498047590256, Epoch: 8\n", "Loss: 0.11168068647384644, Epoch: 8\n", "Loss: 0.0735282376408577, Epoch: 8\n", "Loss: 0.11101522296667099, Epoch: 8\n", "Loss: 0.09046553075313568, Epoch: 8\n", "Loss: 0.049216534942388535, Epoch: 8\n", "Loss: 0.07616625726222992, Epoch: 8\n", "Loss: 0.08197750896215439, Epoch: 8\n", "Loss: 0.11903025954961777, Epoch: 8\n", "Loss: 0.11152675747871399, Epoch: 8\n", "Loss: 0.09254954010248184, Epoch: 8\n", "Loss: 0.10088139027357101, Epoch: 8\n", "Loss: 0.09094111621379852, Epoch: 8\n", "Loss: 0.11987733095884323, Epoch: 8\n", "Loss: 0.07824293524026871, Epoch: 8\n", "Loss: 0.08746776729822159, Epoch: 8\n", "Loss: 0.09556399285793304, Epoch: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.10032295435667038, Epoch: 8\n", "Loss: 0.0936238169670105, Epoch: 8\n", "Loss: 0.09704777598381042, Epoch: 8\n", "Loss: 0.0960117056965828, Epoch: 8\n", "Loss: 0.09698610007762909, Epoch: 8\n", "Loss: 0.08561122417449951, Epoch: 8\n", "Loss: 0.11067008972167969, Epoch: 8\n", "Loss: 0.09560175240039825, Epoch: 8\n", "Loss: 0.07673484086990356, Epoch: 8\n", 
"Loss: 0.08287633210420609, Epoch: 8\n", "Loss: 0.0930718407034874, Epoch: 8\n", "Loss: 0.08970781415700912, Epoch: 8\n", "Loss: 0.1007469967007637, Epoch: 8\n", "Loss: 0.09398171305656433, Epoch: 8\n", "Loss: 0.09208820015192032, Epoch: 8\n", "Loss: 0.09928972274065018, Epoch: 8\n", "Loss: 0.08620091527700424, Epoch: 8\n", "Loss: 0.08966758847236633, Epoch: 8\n", "Loss: 0.09183480590581894, Epoch: 8\n", "Loss: 0.08743667602539062, Epoch: 8\n", "Loss: 0.08666228502988815, Epoch: 8\n", "Loss: 0.09077548235654831, Epoch: 8\n", "Loss: 0.09201730787754059, Epoch: 8\n", "Loss: 0.09336574375629425, Epoch: 8\n", "Loss: 0.0882939025759697, Epoch: 8\n", "Loss: 0.08134134858846664, Epoch: 8\n", "Loss: 0.09802527725696564, Epoch: 8\n", "Loss: 0.073699451982975, Epoch: 8\n", "Loss: 0.1019870713353157, Epoch: 8\n", "Loss: 0.06060896813869476, Epoch: 8\n", "Loss: 0.08239734917879105, Epoch: 8\n", "Loss: 0.0757080689072609, Epoch: 8\n", "Loss: 0.09910555928945541, Epoch: 8\n", "Loss: 0.09245341271162033, Epoch: 8\n", "Loss: 0.08715236186981201, Epoch: 8\n", "Loss: 0.08551301062107086, Epoch: 8\n", "Loss: 0.08673674613237381, Epoch: 8\n", "Loss: 0.09225232154130936, Epoch: 8\n", "Loss: 0.09661070257425308, Epoch: 8\n", "Loss: 0.0856349915266037, Epoch: 8\n", "Loss: 0.09376442432403564, Epoch: 9\n", "Loss: 0.074952132999897, Epoch: 9\n", "Loss: 0.0973762795329094, Epoch: 9\n", "Loss: 0.08727239817380905, Epoch: 9\n", "Loss: 0.09685800969600677, Epoch: 9\n", "Loss: 0.0946836918592453, Epoch: 9\n", "Loss: 0.07464674860239029, Epoch: 9\n", "Loss: 0.1091916412115097, Epoch: 9\n", "Loss: 0.08540654927492142, Epoch: 9\n", "Loss: 0.07873443514108658, Epoch: 9\n", "Loss: 0.09215348213911057, Epoch: 9\n", "Loss: 0.09173493087291718, Epoch: 9\n", "Loss: 0.0822596400976181, Epoch: 9\n", "Loss: 0.10001017898321152, Epoch: 9\n", "Loss: 0.09149155765771866, Epoch: 9\n", "Loss: 0.09980180859565735, Epoch: 9\n", "Loss: 0.07469983398914337, Epoch: 9\n", "Loss: 0.10162779688835144, Epoch: 9\n", 
"Loss: 0.08965972810983658, Epoch: 9\n", "Loss: 0.09272082149982452, Epoch: 9\n", "Loss: 0.09031746536493301, Epoch: 9\n", "Loss: 0.08489258587360382, Epoch: 9\n", "Loss: 0.08522868901491165, Epoch: 9\n", "Loss: 0.09051043540239334, Epoch: 9\n", "Loss: 0.09652537107467651, Epoch: 9\n", "Loss: 0.07605654001235962, Epoch: 9\n", "Loss: 0.08628328889608383, Epoch: 9\n", "Loss: 0.08136139810085297, Epoch: 9\n", "Loss: 0.07680477946996689, Epoch: 9\n", "Loss: 0.08961768448352814, Epoch: 9\n", "Loss: 0.08857917040586472, Epoch: 9\n", "Loss: 0.08441700041294098, Epoch: 9\n", "Loss: 0.09196633845567703, Epoch: 9\n", "Loss: 0.08878505229949951, Epoch: 9\n", "Loss: 0.09737300127744675, Epoch: 9\n", "Loss: 0.10665343701839447, Epoch: 9\n", "Loss: 0.0984821617603302, Epoch: 9\n", "Loss: 0.07351763546466827, Epoch: 9\n", "Loss: 0.08747286349534988, Epoch: 9\n", "Loss: 0.08971679955720901, Epoch: 9\n", "Loss: 0.08926985412836075, Epoch: 9\n", "Loss: 0.08825411647558212, Epoch: 9\n", "Loss: 0.09101610630750656, Epoch: 9\n", "Loss: 0.07893285900354385, Epoch: 9\n", "Loss: 0.07997577637434006, Epoch: 9\n", "Loss: 0.08593187481164932, Epoch: 9\n", "Loss: 0.09999746084213257, Epoch: 9\n", "Loss: 0.0822121724486351, Epoch: 9\n", "Loss: 0.08718079328536987, Epoch: 9\n", "Loss: 0.08673427253961563, Epoch: 9\n", "Loss: 0.08414871990680695, Epoch: 9\n", "Loss: 0.08811739832162857, Epoch: 9\n", "Loss: 0.0737643912434578, Epoch: 9\n", "Loss: 0.09219112247228622, Epoch: 9\n", "Loss: 0.09499291330575943, Epoch: 9\n", "Loss: 0.07242050766944885, Epoch: 9\n", "Loss: 0.10049676150083542, Epoch: 9\n", "Loss: 0.08951854705810547, Epoch: 9\n", "Loss: 0.0999721810221672, Epoch: 9\n", "Loss: 0.09247854351997375, Epoch: 9\n", "Loss: 0.07728950679302216, Epoch: 9\n", "Loss: 0.06448691338300705, Epoch: 9\n", "Loss: 0.07841381430625916, Epoch: 9\n", "Loss: 0.09069116413593292, Epoch: 9\n", "Loss: 0.07123135775327682, Epoch: 9\n", "Loss: 0.09597131609916687, Epoch: 9\n", "Loss: 0.09440863877534866, Epoch: 
9\n", "Loss: 0.10218434035778046, Epoch: 9\n", "Loss: 0.09929265826940536, Epoch: 9\n", "Loss: 0.07066919654607773, Epoch: 9\n", "Loss: 0.11114039272069931, Epoch: 9\n", "Loss: 0.1049976646900177, Epoch: 9\n", "Loss: 0.07026403397321701, Epoch: 9\n", "Loss: 0.09360364079475403, Epoch: 9\n", "Loss: 0.08021758496761322, Epoch: 9\n", "Loss: 0.13227015733718872, Epoch: 9\n", "Loss: 0.0857333168387413, Epoch: 9\n", "Loss: 0.08923828601837158, Epoch: 9\n", "Loss: 0.06564822793006897, Epoch: 9\n", "Loss: 0.08624057471752167, Epoch: 9\n", "Loss: 0.0866137221455574, Epoch: 9\n", "Loss: 0.06691738963127136, Epoch: 9\n", "Loss: 0.0908471941947937, Epoch: 9\n", "Loss: 0.07081125676631927, Epoch: 9\n", "Loss: 0.07842528074979782, Epoch: 9\n", "Loss: 0.11552178114652634, Epoch: 9\n", "Loss: 0.09867746382951736, Epoch: 9\n", "Loss: 0.07390105724334717, Epoch: 9\n", "Loss: 0.08417465537786484, Epoch: 9\n", "Loss: 0.0967276319861412, Epoch: 9\n", "Loss: 0.0920858383178711, Epoch: 9\n", "Loss: 0.07572290301322937, Epoch: 9\n", "Loss: 0.08656008541584015, Epoch: 9\n", "Loss: 0.0918826162815094, Epoch: 9\n", "Loss: 0.10554063320159912, Epoch: 9\n", "Loss: 0.08934647589921951, Epoch: 9\n", "Loss: 0.09057360887527466, Epoch: 9\n", "Loss: 0.09137235581874847, Epoch: 9\n", "Loss: 0.08764832466840744, Epoch: 9\n", "Loss: 0.07917024940252304, Epoch: 9\n", "Loss: 0.0912671834230423, Epoch: 9\n", "Loss: 0.08849883079528809, Epoch: 9\n", "Loss: 0.09417363256216049, Epoch: 9\n", "Loss: 0.04771730676293373, Epoch: 9\n", "Loss: 0.10762224346399307, Epoch: 9\n", "Loss: 0.08276030421257019, Epoch: 9\n", "Loss: 0.052407264709472656, Epoch: 9\n", "Loss: 0.06693300604820251, Epoch: 9\n", "Loss: 0.07219705730676651, Epoch: 9\n", "Loss: 0.10493184626102448, Epoch: 9\n", "Loss: 0.1038055345416069, Epoch: 9\n", "Loss: 0.08958692103624344, Epoch: 9\n", "Loss: 0.08443622291088104, Epoch: 9\n", "Loss: 0.0833052545785904, Epoch: 9\n", "Loss: 0.10586792975664139, Epoch: 9\n", "Loss: 0.0817166417837143, Epoch: 
9\n", "Loss: 0.08074997365474701, Epoch: 9\n", "Loss: 0.09111910313367844, Epoch: 9\n", "Loss: 0.07266872376203537, Epoch: 9\n", "Loss: 0.09336880594491959, Epoch: 9\n", "Loss: 0.08866416662931442, Epoch: 9\n", "Loss: 0.08915765583515167, Epoch: 9\n", "Loss: 0.09368506819009781, Epoch: 9\n", "Loss: 0.08114416897296906, Epoch: 9\n", "Loss: 0.10218866169452667, Epoch: 9\n", "Loss: 0.09324317425489426, Epoch: 9\n", "Loss: 0.06947305798530579, Epoch: 9\n", "Loss: 0.07948879152536392, Epoch: 9\n", "Loss: 0.08767526596784592, Epoch: 9\n", "Loss: 0.08044221997261047, Epoch: 9\n", "Loss: 0.09291483461856842, Epoch: 9\n", "Loss: 0.09230954945087433, Epoch: 9\n", "Loss: 0.08549445867538452, Epoch: 9\n", "Loss: 0.09530370682477951, Epoch: 9\n", "Loss: 0.08001953363418579, Epoch: 9\n", "Loss: 0.08556060492992401, Epoch: 9\n", "Loss: 0.08849553018808365, Epoch: 9\n", "Loss: 0.08635181933641434, Epoch: 9\n", "Loss: 0.07598816603422165, Epoch: 9\n", "Loss: 0.08768333494663239, Epoch: 9\n", "Loss: 0.08585868775844574, Epoch: 9\n", "Loss: 0.08838070183992386, Epoch: 9\n", "Loss: 0.08511821925640106, Epoch: 9\n", "Loss: 0.07946973294019699, Epoch: 9\n", "Loss: 0.09176114201545715, Epoch: 9\n", "Loss: 0.06730469316244125, Epoch: 9\n", "Loss: 0.0944196805357933, Epoch: 9\n", "Loss: 0.05201512947678566, Epoch: 9\n", "Loss: 0.07947589457035065, Epoch: 9\n", "Loss: 0.07191232591867447, Epoch: 9\n", "Loss: 0.09604261815547943, Epoch: 9\n", "Loss: 0.09151257574558258, Epoch: 9\n", "Loss: 0.08036286383867264, Epoch: 9\n", "Loss: 0.08668318390846252, Epoch: 9\n", "Loss: 0.08424527198076248, Epoch: 9\n", "Loss: 0.08812244981527328, Epoch: 9\n", "Loss: 0.08832524716854095, Epoch: 9\n", "Loss: 0.08474713563919067, Epoch: 9\n" ] } ], "source": [
"# Gradient accumulation: one sample per forward pass, one optimizer step per\n",
"# ACCUM_STEPS samples (an effective batch size of ACCUM_STEPS).\n",
"N_EPOCHS = 10\n",
"ACCUM_STEPS = 10\n",
"\n",
"# The Testing section rebinds `net` to a model in eval mode; make re-running this cell train in train mode.\n",
"net.train()\n",
"for epoch in range(N_EPOCHS):\n",
"    optimizer.zero_grad()\n",
"    n_samples = 0\n",
"    for n_samples, (image, context, prediction) in enumerate(dataset, start=1):\n",
"        image = image.unsqueeze(0).cuda()\n",
"        context = context.unsqueeze(0).cuda()\n",
"        # CrossEntropyLoss needs class indices; argmax over dim 1 recovers them,\n",
"        # assuming prediction is one-hot with shape (seq_len, vocab) -- TODO confirm in util.UIDataset.\n",
"        target = torch.argmax(prediction.cuda(), 1)\n",
"        output = net(image, context).squeeze(0)\n",
"        loss = criterion(output, target)\n",
"        # Average (rather than sum) gradients over the window; Adam is nearly scale-invariant,\n",
"        # but this keeps the update size correct if the optimizer is swapped for e.g. SGD.\n",
"        (loss / ACCUM_STEPS).backward()\n",
"        # n_samples counts from 1, so the first step happens after a full window, not after sample 0.\n",
"        if n_samples % ACCUM_STEPS == 0:\n",
"            optimizer.step()\n",
"            optimizer.zero_grad()\n",
"            print('Loss: {}, Epoch: {}'.format(loss.item(), epoch))\n",
"    # Apply a trailing partial window instead of silently discarding its gradients.\n",
"    if n_samples % ACCUM_STEPS:\n",
"        optimizer.step()\n",
"        optimizer.zero_grad()\n",
"\n",
"torch.save(net.state_dict(), './pix2code.weights')"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Testing" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Pix2Code(\n", " (image_encoder): ImageEncoder(\n", " (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (conv5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n", " (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", " (fc1): Linear(in_features=100352, out_features=1024, bias=True)\n", " (fc2): Linear(in_features=1024, out_features=1024, bias=True)\n", " )\n", " (context_encoder): ContextEncoder(\n", " (rnn): RNN(19, 128, num_layers=2, batch_first=True)\n", " )\n", " (decoder): Decoder(\n", " (rnn): RNN(1152, 512, num_layers=2, batch_first=True)\n", " (l1): Linear(in_features=512, out_features=19, bias=True)\n", " )\n", ")" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net = Pix2Code()\n", "net.load_state_dict(torch.load('./pix2code.weights'))\n", "net.cuda().eval()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "test_data = UIDataset('./dataset/evaluation', 'voc.pkl')\n", "vocab = Vocabulary('voc.pkl')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQAAAAEACAIAAADTED8xAAA6PUlEQVR4nO292XNdSX7nl3n2/dwd92JfCO4sVlXX2lWt0ahbrXVa40c79OSYB/vV/4Ye7Jlw2FboxWF5QhEe2WHFqGPG0y2pu9TVVaydO0GCJJYL3H07+5KLHy4LRJEACQIgANbNz9O9Z/2dPPk9mfnL/GVCSilgMEYV7rgNYDCOEyYAxkjDBMAYaZgAGCON8PSmWq328ccfp2n6oteCEF64cOHixYsQwno/+M1SO8X7bGEXDPF3zo3JIr+/0xmMPfKkACilf//3f/9Xf/VXhJB9XO699977i7/4C1VV/9M3m//LL5f37WHK6uJ0QT8zbu/zfAZjb+xQBWq32/vL/QCAXq+XJAkAoBegg/hXo5T4Cd7/+QzG3mBtAMZIwwTAGGmeKQAIAYDf+bEFLwKOB/D5+uE4qMuCyHMAAIGHOx4j8hzPweExDMZR8sw8p+SBYgMIgWQD2QCcADgBQB6IOph8H1TeAGYFcAKAHIA752wAwGRO/29/d/7CpCUJ3NvzOUXkJYETeW74QxI4nuf+6PXxi1OZt+ezArf7hRiMl8AObtDH8DKAEtAKQLYAJwJZBzgFKAROHfRXgDkJMrMgtwiiHogHoL8Kdmr2SgJf64W6Ir05n/vJxbHFMVMQeEKpLvP9AHl+9Ot73dminjfkjCoULOW/XKtF6T6b4AzGi/JMAQAAKAZqHggS4AQQ+4ADACcAcoASACBQ8wBCICiAJDvm/iGSwNm6yEPQ6IW8KJytmF6UXlsfzBQNlQdFSylacpqilW44ZsgJYrmfcXQ8rw1QWAS8CBIXEAL0IpAsoGRA6oHieWCVQdgBsQPCLkiD3a6BCVUlwZDFMxUra8gKz210g2ondEOkSTzPQVMW/uOXGykBEg8jTFTW+cU4Qp5ZAvgNgAJgTYHBCsApsGdB1AFqAUQDsPEZ4CUQu0CQgDEOwt5u16h2vf90NU0xkXgOEYoJ5TkYpyRB5Ga1DwEdBGmK6d2aSwHFhEasBGAcIc8UAElB2AMoBigAlILuMqAIhH1AMYj6j45BIcAIoHC3a6SIbPZ23hskaOt3P0he3HgG46A8rw0AKEj9Rz9JAgAAOH7ykK0DGIxXDeZ6Z4w0TwoAQlgsFjlun8LI5XKSJAEA8rrAHcClr4qcIbHWMOOls0MV6Gc/+1mpVNrfcOhz584pigIA+KPXJ7KGsu/h0HldnC0Z+zuXwdg7kAXFM0YZ1gZgjDRMAIyRhgmAMdIwATBGGiYAxkjzHTfovkOBGYxXiO3dXI8FgBDah++fwXjlEEVREB7l/McCwBizPgHGKIAx3hLA47KA5X7GiLA9q7NGMGOkYQJgjDRMAIyRhgmAMdIwATBGGiYAxkjDBMAYaZgAGCMNEwBjpNlZANW1lZW16tbfNIl+9U+/8oPoqKxiMI6InecFareaTkxa9c18sSgKQqO+ce3GzWKxyIvSmdOnIJvCmfF9YdeJsR7cW0qiyM7kHizfnZqZoQR/9NGvAeArlYptsfkaGN8Tdm0DeJ7/7g8//OCH7wJA5uYWivmc6/kcz6UI7XYKg/HKsXMJACFnWcZvPvonw7A0zVhZeaipWiGXlVVdkeUjNpHBeHnsLIC5hcXK1PRmdSOTzamqEoaBqupR4EUp0TXliE1kMF4ejyfGiqKIhUQyRgGO44bzFwLWD8AYcZgAGCMNEwBjpGECYIw0OwsgiiKEcOD7wyZy4PtJmsTxk6sY+X5AKQiDYNh6juOoPxg8fbU0TXx/r6vIpGn69I0OnTRJet2ut2erGN9XdnaDPnx4HxEucnqL5y7Isnjvzm3ZNKM4tXRNkmRd1wElkiStrq5PTk226jVJVVVN6zYbm83Owvysoiiapg/6PUGUoihcfbgcxOmlS5ct0wijSBLENE1dz+V43jJNjuMkSUJp4ocRoKTT7TpecO70ojNweJ7DhNq2RTCGvEAwCsNQUeQ
wjCghsqpqquq6LgBQEHhVVR3H4XgepakgCIQCHgJEiK5pqqp6rhvFMaGkWCgihAa9/jdffzM1Nz85MQYhr2maLEtHnPSMk8DObtCvvvz83vJDHEeFyvjCwsLS9a8jKIZhEPm+aVpJFBXyhWKhcPfeg+JYwe93EgrsbHa8VPzq66uKqnhB+Ps//smVT34bI4QxEjiACJUEMVco9jrtJEUiL4ZRKMmywPMI47feebvT2KhuNgCAjucqiiqJIsHY0PVOt6epMqVU1c0kCgEAaYpQEkGO50WhVCo9XFnVVQVAvjRWvH/vrm7aCCFREnkIAQAUE0TI7/34x9e/+Xpzc4NA7ic//kl1faXddlxnICuKpsmdXv8P/+iPi/nscb0DxhHzfDeorCjl8Yogit12S9O0JI4xwrpuLJ49O1Ypp2lCKOk7A0EQ260Wz/O6YaqKbFmZUqlULpesbC6XtSenZjL5nCQKhmlUxsdLpXyr2bRMs9frirK6eOr07NQMJajX7YiCCCFvGrpuWqZpTk5NRUkyPj6OEb5w8Xyv10OEJlEgiFKxkKMAFnPZYqlkm3qz0ZqamS0V85qut1utqelplKaSwAPIWaYhKZptmVEUchAqmjFWLOqq6oUxQYgAkM+XDF3hRSmJo61pkhijxs4vXpakfq9/9tx5RJBlWa+9/uZmqy2rmmWZUeCfOXcep0mn25tbmIMQx567sl5NET5/VucgabW7i6fPCqJQLBV4kYMYW5bZaXfypbHyGDczO2tns6Zpy5IMADUs3Q/8bC4bB44fxgSnfhB12+0zi6fu3llKUpQp5C6+dknXLc/pmla2Wt2YmijfuXWL48XX3/pBqYx104zDwE5Ru9O1LZNSQDDp9juImijyNr3wtcuXCaWyCO83W/lSyTQMY2Y2kw8JJmkSxghfvnw5iSMAzCNOesZJYOcqEKWUUjoc9gzhzssoDQ+glNy6cbNWr5cnpi9dOFvbrFqZnKaqW2d9exiNo/DOvfuXL116ejD11sEEo5X1jdmZaQ7CTrsDIJ/PZwAAGKU3bty6cOkiz3EQwvtLd6Aoz8/PbV3Bcwar1c0L588N77i6tlYulz/75Le50vjF82cghJSQ1ZWHE1MzoiiAR3OD0ds3bk7MzFmmzgZ4jxTbq0CHMBRiSy3PzUZbotrHMdu3U0oB+M5Rw6d4vIlSACEhZLtVT195L/Ywvn9sF8Ah1H33kvW3jtz3Mdu3P33Mk1sgBN+dBXtPZzFGD9YRxhhpdi4BKKApTvc9XzQEUOAFDjJ1MU46Owvgbvvuz+//HNF9Bn9BAN+tvPujmR9BwOoYjBPNzgK43r6+FC7tv35EgdAS3p98X+TFfVvGYBwBOwsgwQlC6CCf7wQnbMUNxslnlzYApZhgAAHYysNDMVAA6J5aziy4jPFKsIsblAJCCAf5CX1C5MQYx92oHeKook3O6ZNfdb6K8FMDNrfUAgFgCy4xXhF2LQEIJhTC+eypaXOiE3aC0LvVuTmfWVAg/2bhzSiNEIUJDizJTgkWOVj1Nywpo/Li1c41RAklTACMV4BnCQAAggnkAMyrxYpSrvubLa/1Ruki1UiAk4xg65Jyp3t32qxwAExlpyVBXe7c5gGXYMQEwHgl2Lk6TynFGGOM/cTvhr0UhU2/m5EtlVfbftcSdZ4IuqRzkIeY80J3w9mUBNWJBjzmKaEYY0JZG4DxCrB7GwATAMDt+rUHgiRyQtEoN5x1AkWF52u+vOlsPOzf5yGlFNY6tyIc32vdoQAW1WwQR4RVgRivCM+uAoGO1xpuWe2uDn9s9FcrZqXpNfBOfp6VbVc4ZEsZjJfA7m5QjHc7p9qr7rZrC+YGZbwS7CwAHvIHzMEc4NhYS8bJZ2cBXBq7dG3zWorT/V2Ug9zb428LHIszZJx0dg6IIYQEcbBvTw6EUJEUNhCIcTJ5fkAMx3GGylbBYHz/YUP2GSMNEwBjpGECYIw0jwXAvJaMEWF7Vn8sAJ7nmQYY33s
ghDzPP/675Qal33JMhjEYRwH8lkd/WY5njDKsEcwYaZgAGCMNEwBjpGECYIw0TACMkebxYDhKKYtiOTK2u6KHPCMCiXG4cNzjYJXHAkAIpek+AwAYLwSEUJKk7RrAGCcJm0vviBBFURQfjdV/XAVin/8j4+nClhDCcv+RsT3xHwuAvQDGiLA9q7NGMGOkYQJgjDRMAIyRhgmAMdIwATBGGiYAxkjDBMAYaZgAGCPNzgJI06TX63meDwDAGIdhyEZJHCXOYNAfDJ7ensQx67A/XHYWwN3bN//tv/2f/vMv/hET6vR7n165srq2TghhncVHQOC7//7f//X//n/8dd9xt7L7MFz7y8+udPvO8Zr3PWOXZVKT5M133p+fGr99+1ZtY4MA2Gk1P/7o1+Wp2Z/++HfZ5BEvlcDzOFF567Wz177+cvnByuzMzPh45datG512TxHFuTMXjtvA7xW7rA9A8GeffJz4l4MwSJMUU/ogiccnJ5vNBqWA5f+XimlnTEX49MrnF8+fCX1vaWlpY6M6NT1dXd8wteJxW/d9Y+cqEMcLf/jH/+qnP/1xo7aZyeYgAIQQ3TAFgc14/tIZ9HvjM6fmpsevXr9JCC0UCg9X1lAST8+dymas47bu+8bOAlBU1bItw7RnpqcnJyd0XSuWip9f+bTHKqAvH03Tb179arVaKxXy/cFAVbSMbbXbrbt3lwZ+IAhPRtIwDsLO6wMglFIARUGIo4jj+S+u/HZto9lptyqTM//6Z3/CqkAHZ3tMBgAgTdPtfjbPdTGlsiiGUfTg3tLKZuuPfvp7cZJKkqQqCscx5/WB2L4+wM4CeIJWs8HxQr1er0xM5jL2EZn5vebZAthOq9mAglTIZY/KtO8/LywAxqGzdwEwDp3tAmCFKWOkYQJgjDRMAIyRZmcBeK7bbLXjOPZ9f/v2OI78INi+JU2SjY1NSmkUhc8YKJGmaYrQHm0iGDvud/ytnuM84/RBf4Dx81svlNJev58kjx6q2ahXqxt7NOnocZxBq93dviVNE/TMNOwPBvjbVhxGKEkeNyqcfp8Qsv0thIEfx/HWvdAzZyVKksTzgt32Ukp7vf6Ob9/33GartduJg8GAkCdP830/jpNnGHO47NyxNRj0Pv3i67mpyU6vd/78BV1TFVWt1zYGg0Gr3blw/kK/P5hfmK+urbZazXqz/f6773722ZULly7Lkug4zuzsbL/Xp4A4jmsYhiQrvXa95waL8/PtTqdQLBGU1OoNWZbTNJ2anvGcvuuHksBphqUpUhiGrW4XEGKYJqRQ0/WV5btmLp8mCUIol8/5ns9BEIZxNpfL2Pbtm7cLY6UkiQqFIsapLEn1RlOW5SiKNU1JEkwJ4kVRU5Svrl6bn5tVZCkIont3b4dB8sMffYCSpFSuZOyT0scUR3EQhveWl5IUQjDvup4sib1+r9NuSbJqmhbHc1EYWnbG99xCoTAYDLLZLABg6d7djJXNZjOiKDTrtU7fy2ZMUZR0Xbt3+/bEzMyg14W8mMvneAidfl+QJFnVCEobjQ3DLkxPTqI06XS7Gdv2/CCbzQAA+r0u4IReuw2gYFiaoWtJiifGK3Ece54nSRIhRFPk3/72s4Uzi2kSWabluq4sy5QCSnCzUYsxHS+XLdvGCOuaWqvVNUP3XBdCjuMgITSbzfb7PdO0eB4OBgPX8TCFY6Uc4IRT83Mve9zNzl6g27dvXPn8G03mBVFOCZmfnSmPjd28dbvbbsuqksbJWLly5vTitWtXG426quqmaQ6cAUZY0w3bti9fvvzrf/wFgsKg15MEfu70ubGccXvpHgSg1midOXv+zPz0L/7xn3TDTKL4wmuvD9r1VqeLKLh44aIA0b17y4PA12QFo9QPkrFSqd3YJIIoSwIhRNfUbqvD8wIviKosT03PfPP1N3rGhIRYmSyhaCyXb3W7jusLAo8QTqLYss0Epb7nq7ohQAo5DkDODwIRchwHNcN67fLlmenJl5rQT/AML1C
z3lhZq1ZrVd8JS6Vsr90mvKirkus5Ei9wguQGoSjwAi/JotDrdeM4VjWNF0SeB54flvL5drutqjon8JggURAJoTSNNTsjAOC6gxSTbDaniEJvMDAtu1avW6bGc3yKKARUFPlutxvHiabrvCCOFQuNVlvXNGfgmKYWxIltGq+9dnnpzp12qxVEcb6Ynxwr3bp1VzY0QRAj39c01fVD29CiKMSIGJbhuA7BNJsv8IB0uo6oSgQTURASnHKA5yGVJDlNkjCKNV1No0TWFD+MCrnc7/zOvzB09dAT//leIEppsVwxTd0wDYJQp91OMRpOZmYYBiZY13VKiSwrhm5MTU66nj9WHuM4gBDSdQ1AQClAaWqapigKGBNN03LZjOv5pmEaht7r9yYmJscrJUIQQlgQREWWeEGQJEHXdc8PJFGQZQVCqKiq5w1s2yYIlyuV2dnpKAoJIfl8YX5mWpbEdrcDIMQInzt3zvM9SZIazVYQJYV8XpJEyHE8x6mqJok8wbhQKlm2jjBWFcXOZC3b5jkoKyrPn6i2EO31OkmSGLrR6w0M00jTVJJkwzAydrZSKhcKpZm5OUWRdcNIESqNjUFIg8A3LTtjmVEcRVFkWJZl2TNzc2OVMY4XZFmSZdW2zGKxmEQRgJzA8xBytm0ZmhqGkWVnMEpESdINnRJSKpcJIUEYpmlaqVR0w+Agp+sahFCURMsyB84gl89FUSTwQqvdzeaKlmVpui6K0uTEpJ3JmZYWRLEsyxQTQgjGWNM0x/Hm50/phjE9M1OZmKAATkxNCiKn6zrHcZjgqckJUzcmJiYFUZQVBb98v/zOVSCMydTEuCrB9fXNixcutjvtfK7g9vqWbbv9/htvvNnr9TO5vK6u9QRhbn6+MDaWxmEuV4SUOJ4n8Pyl19+Io7jX72ezWU3TDdMqjVUmpmb63U6+WIzcHuS48fGJJE4sW6926rqdLakSx/H5fOHSpYtr66u1zc2Z+VPjlfJGtXpq4VQQeKKsqIqkKlqz2Z6cmLJMvTBWWt/YfO+DDyhOimNlzTAoJetr6wJClJJGs5XN2BGAgsAhCj/48EMKOEmEnXb72vWbb771dhx4YqWcRJEgSi87ofeOZZmqIpUqZwzVwDjhed4Z9G9cvzExPVPM53O5XDFJFEUWAazVGh9++IHr+flcJkmQosgbGxulsTGUpoZhdDrdYqlEATVVzQ9CUeANywo89+z5CzwHdcPMFfJBnJQr45IoIEJn52ZzucLSnZsLp0+rqp69dDFN0l6vywti4AfnL55febgs8rDd7omilM8Xz549MzU1o6jK2ura+MRkkoS37ty7cPFiLmPbubwkAMiJum6psrhZq49Vxg1dnZ6a1DSjFBUUWQEA2LZpW5nxUv7mjduZXOHMmdOmbeUyed3U16trmma8jM//E+xcBVpfWymOjSvyc7IFIbjRaFQq4wc0ol6rF4rFrVEuge+trK62Wu03f/CWLPIPVtZOn17kXrAuuLmx/mBlXVMkUdbOnj7V7ffGSmPDXc1Gbenu/bffeee5D/jyeKGOsDDwP//iy/MXLhXyj/uDHccBgLMs41DsIQTfu3dvbm4hTaIUk4y9Q39/v9ddXr7HS9qpuZlmq7OwMLd9bxSGXhjtr8fadV2MaWbbUL/Nen2sVOJfzqAP1hN8/LCe4GOE9QQzGI9gAmCMNDs3gmMU1906pvtZskERlLJZ5iCT1n6hmDjrFIX7PB1CTi9DOXOYJn1/2VkAn9c+/7sHf0fgfpoEGqf9m0v/ZsaeOZhhowvu30df/Y8A71cAAMDSW9Jr/x3gWPje89k5japutUd6YF99cB72umGXCWDf0KAJohYE+5+Bg7qrFCeQCWAP7JxGhBCCyP4EAChbbOZgUAJICg4gAED2OuyKscusEIASQgDc9haGv7+7hQIKKXy0d9vJjINBAUHDdEQE8pDCYXFAAQGQe5Tiz4Sw9fb2ym7TolBMMABwzl4IU6+iT6w6D01RXXU
3xs1pQpJJc3K5d0+Eckkbq/vrnbgPKaDwkRIoE8EBoBRQggAgVMqByZ+QzjXgVTmtAKQ8yCzi9f/Cxd2hHobCoBRwj5IdcpACAOi+vBejyS7VRAoIJgDAol7sB/x8djEv5afs8Y/WfqUqOYXA84WLC9aCJCiGoKwO1r3EMSX7y8Znm0EDgoNUXxkAAApICigRrCnoL8PixaT0lhBWAU5gsklnfpK2bwJ7EUIIe3eoMctJFuzfJoJBJYNWfwlxDJgA9swuJQClQwFQQikhtpIzOakeNm0lQznpbOGUyHGttG9Bq546HE+nsjNu3M8oVtWtwX02HRjfQiklCFAMIEDOmlx8Uw4bFAWEYBLWRfsMN/UjEDagkqfquxBimqZw/G0qmgjytPUJ8TzA2mB7ZtfRoBhjjDFHhNP2YsdvIEr8qB+niRc7H6//ph302kE3QWmUhBwVISFtv4NQijFmK34eGAoIAgTRyCHFH2KcpLxFAaRAQLk3ibsGcZACmaAIQooiH0MRkpS6azTo0DQEQ/Ew9sazSwBwtf71tDVRdaoSL2OShCimAFBKOn4vxtGyeB9QLPIqB3A/dlMcE0z213vA+A4kBRTjznWR4/DGOtTK2K9RCoVgPR2sAsmm039C4gaq/gbwqjjxYbzxEUUhp4+RsA0IZgLYO7u6ioeuzI7X7HjNp/e6kQsA6ILu07uYAA7KoyoQAiRNNz8GAABnbbiHeGsAABD3+fpHOKwTvwEApHGfRF1AEfG+jfBkAtgzu3qBCNpnPiYcqwIdkGEV6BmDQ1Pc+GLrz+N8/3gT+wbtlZ0FoEv6vqvyEpQUQTmYVaONoFEA6QE6s6CgAcimEN0TOwvgw5kPUYoSvJ/g/KyaXcgtHMyqkYYvXASX/wcStvd3OgRQmPwQCicoxu0ks2tAzL6HM0AI2Qoaz+XZATHD9WD2fXH2Cp7N9oCYXRvBbAriY4Tl4COD5XLGSMMEwBhpmAAYI81jAbBK/1HyRBWf1fiPku1Z/XEjmOd5FshyNEAIef47fnqe53meZx2IR8P2xIfbE529gCPj6U8+S/wjY3viQ5bujFGG1fsZIw0TAGOkYQJgjDRMAIyRhgmAMdI87gcghCCEXhWn0NCVvuXQpZQihF6hfgye5wXhceIjhPAzl6k7UXAcJwjCljNxGD7+CuUcQRC2+sK+8w6evQjhSYMQwnHc8DVgjF+t+fWHxg9fw3D9xlclAwEAMMbDbAQAoJSmafoKfXqGSNKjeInHVaBX7hm2D5p/hXLPkCeMfxXt3/rxyhm/PauzNgBjpGECYIw0TACMkYYJgDHSMAEwRhomAMZIwwTAGGmYABgjDRMAY6TZWQC9XndlZaXRaL5ynXxD2q3mw5VVvK3DDyPkeR4A4OSPd6KUOI5zsm3clXarubKy0h84AIA4jqIoGm6nlLqOcwJTfmcB3Lp+7X/+d//unz765zCK0jRFaRrHMSEkiiKM8fD3ERu6dwLP/dv/8B/+9m//dvn+ShzHCKE4jrud1iefXkmS5MvPrnR7g+O28Vmkcfzrj34VRVGSJMNxZsNHGI72O8kpTwn54rMrf/m//a8fffxpFEUPH9y/cfsOwjhJkigMPvroIz/Y/+LHL4mdp0Z874cfNGr1d3/47t/8n39NAK/JQt8PL5w988WXX567cGl95f7b73/4g9dfO2Jb90izUZ+cPXVqdvyff/2r/+wOpucX2rUNK5e7d/fu6spqp1WXzFw+lzluM3eFUhr43t/9v/93GONiPosJlUTu5q2lcnn83LnTkqJdunDuuG3cGchxv/eT3291OoYq/eVf/mW5Ur59+3az0WjVNhTd2FhbWV3f+PM//3PL1I/b0sfsXAIMByp22+3yxLSmCGGUTpRLN67fuHz58q2b15M0Xbp79+SVZo9I4sQwDE3VBr3+9Ox8u153XF/ghYXFM8W8VSgWy8X8cdu4KxijMIwoIZpuDPr9OAzur6w5jnvp4gXXda5evQrhiW6
2bayvAUEpZE2EMM/zr7/++vLdpWyxjJNIMzOVsYLjusdt43d4XmpSCiiQZdU0DEppEAaUUsOwxiuVE7sWnmEaa6sr//Hvf67omm3bqqrNzc8uLy+jFAMIOY4jJ1a7AHTbrV/+w68cx93YrAk8H8VxmiQpQmEYFgqF6sZmPn9y1UsJ+fST31bK44IoF3KZhw/XbNsWeGE4ZFRRVN3QT9oSursIAMJiaaxcqTRq69Vao1gq5vL5ubmZq99cnZiaSeJI142Tmv9BqVzRJLixWS8Ui7lctlQqdrrd6dk5z+31nHBmZrrXP7ltAN0w261acWxCU+QwDDEQLl86Lwri1avX8oXihfPnMxnruG3cFYxRGMVXPvnnr69ec1xvcXExm82ePXe+12lKqjk/O5PL5RVZPm4zv8Ou6wNgjDmOT9P4o19/9M577xu6duWTjz/97Kvf+/2fnju9sD2m5riAEMqyPDTjifn1h61GAICqqsMVLzmOp5SAb2cFO3bjAQCyLA+NGfoVtrYnSczxAiWEEMJxPMfBX/3y/7txZ/lP/tWfzc1MPjGl3HGxtb4BISSO461cNIzs4TgOYzyMGqMUEIIh5CB8NO37sc8DuX19gF0FsIXv+6qqcRyMorDd7oyVy6Kw66oCR8kzBPBKsJsAnibwvV7fKZfLPH/8uh2ymwBeCfa0QMYWuv6oza4o6uTk5Eu0i7ELmm5ounHcVnw/OSlfFAbjWGACYIw0hyOAPVYBnz7s0YadIsRfXq3y5VVYX0aE+N6vuZ9bv/yk2KP9xzVL+bMEQCkd9Hrtdgd8+xhPAwCo1zev37qDCXli+/bnH7pibt68tbnZCAN/s1ajlKI0rdUalNKN9bW19erw+OrayscffxoEwdr6arvd7vUGW1dw+v1Wq/XEXXaznFIaRUGr3X7SeEJq9Rr51trtxz/xe2vL0/d9Kg0enbixvra2vhFF4fC+ruvUmy+22mkcBpsbm8OLDvr9Vqt7986ddrfvOs5g8Nh72+t2om1NT0opRunqyvpg0HdcN/Dceq0+NK7dbqcp2m5uu9Pq9R+lanWjGieH7DwYpvLSrZu9vuP73o0bNz777IstGwAAvu8NXG978hGMV9bWH70fjFdX1obZqVatBkH49Bs5RHb2Aq2uPPT94OHKCooTyzZ100rSpNVql4r5RqMFACUUTE9OtjudqemZer3m+KGpyqZpDhxX5GGC0JnTpzvNFqbAcx3I851OW1G1FBFNUly/b2dzlm70+z3LzjlO13VcUVJESZybm7321Ze9gTcxNSkrogDl0HMkXVNEQZDkjfVV07LjJBUlEadpJl/4yY9/T+B58F0vkOv0P/n0M0PXB6575uxZSNJr129yHBcEYaFQnBivNFoNz/V03QjDcGpikuf5pXv3stlsuTxWKBSufv2lnilMlQufXvni7PmLSRLevX1LkhVeEGVJtCyr0WhKAjfwgnK5kiSRZWVyuWw2a3/x+Ree54myqqmybuiSKPm+X6xMXjhzasek39EL5A36X3zxpaCooiiEXjg+OXXr5teCpOqafu7cmTCI+k5fkuSvv/gsXyrphv3a+bO9/qDVana6XUB5O6NhQnCCMCEAgsrE9L2lmwiTNE0LhYLnehMTE0t372YyWUkUEUayqr37zruG9sILm+/oBSIEX/3mGz+IBoN+p9XM5guGaTcaGyIvl8ZKkiS2Wq2JyemNtQcJ4ThAdE3DlJaKBULp+kYtb1mNZkOUFZET88Vsq9UMPVfSzFzG6rve+bNn+90uBUAzjDNnzgoHcIg93wsUhcE331wtlccj1+t3uzdu3imWxyqVchAmkiyJosBBeP36VUJBt9e3bKvbbbfi2LRs1xnMzi2EvtNqtfqd9nqteebsWQ4gO2NzELidQTGXBxAv3bm9uLA4OTl54/rtyZkJz40AhDNTU0t3lnK5vGXlVE3+9LPP3nzzXYSQgtOvb94UJfnSxQsbmxvlcmVlda0yMQ7pzuM6G43
6wwcPZV1/49IFz/eT0Dd0tee4mazt+87SkuP67tTEzEZ1bWpm9t7de6ZlcaLkOD3HcVRFkRQtSVJN1xVF6Xd71dpGuTJBMOo7zsrKBuTFMwvzzUY9Z9mN6lrP9w3dKI2VeThlmkYQxu1WE6M0X8y36k3Nsscmpl/w7dCbt26dPnv2mzt3coWxqZnZYql449YSR8HM9HSn3ek5Pdf1C6UiB2h9Y2N+aqLVbK48XCmVix//9spbb78ZB54XpN12a+H0qVw+V65U2u2Ormuu088XSnfvLk1NTTcb9Vq/J8lqoTT2guY9C47j2426G6Ew8DKmJSt6r91UDFuFYPne3amZ2Xan43m+phsCDzJWZnXloSBJtY1Njuda7Y6weNY0MxyE165enT29UMhlA8/VZLnv+AQljWaj12yubmwunDp77tyhjYbaWQCmaciqFnoDzw8mJiumF6iKvLm5GQVBsVi0LIvnYdbJcoIMCDZMywtCTlXTOIYc1+t14ihAq2sUE0UWm426LIulsSLBuOcE62trVtaamZl23cHKSsoLYqvZ0nUd4aTebAJAcvn85sZmtxeeWlwIAi/0PURxPp+XFaVerwFKNzY27FwujYJ2sHOBqChqJpsFEKysrs0vLkJKq9XN8kQla9sI6Y16K5uxm81GilCr1QIAVzc350+fFkDabvct2+p+0xZUs7q2hjCu19ajMK7Xa85gUB6fsCxbEITq5kbGMLOZDCWIimIhlxUVVZHlZqut6wYHQYLSTCbjub5t24LwYv1WEHJzs7MoRbpp2pbF87ym6blsFlLKcTDw+67nAUANw7h//4EsawDC1bXVFONOt3tq4ZTnuoASQMn09BRKU0PXIKCSLFmmgdJ4c2NTkuTNjY2B43IQSoIYReEh9koRgimAGCMOchzHR4GbYGLLMo0CSZIbtRohVJZEwzT6vd7GRlWSZUqAJEuCIJ49dzYK4mKhSCmenp5BSbJRrUqCmLEzK2trXuBFCdIVWdc10zS5wzN65yrQ+upDJ4htXX1wf+X1H7yBMe522p2BOz5WStM0n89zHEzTNAhjy9ABhClCge8LghBFEQfBw9WVmZk5RVYURfZcV1ZVVZEppUmShmGYpGkul03jJE5SwzA81zFtG6M0SZGmKgIveL4XJ6ltWXEc37t1szIzrykSJ4i+5+q6vnx3uTBWXlm+nQLxv/qzP+WfqgIhhGq1mqqqSZJks9mHy3e/unrjxz/5iWloAMAoDO4tL+cyGcf1K5WKqip+EGayGQ5CQohpmtW1lbV658KZhV6vL/C8pChhEAAwHEYENU1bXrpTKo8XiwVCiOM6vW5PM+2ZqfF6rW7ZGYSSMIwty0AIczwvy/Junf87VoEwxoHvpwgNJz+VZQVj5HnuZrWazZdMQ41TpKkqhHDQ71PIjZWKjXpd07Q4SUzTun/vjqJbpWJB4IUUJdlc3nfdFCFZlpqNzXbXWTx1qt8fiKKAEAYAWLZl2zb/4v3iO1aBKCGO6yZJEoRhxrbjKGy2WpQTBs3mqXPnAMFBGNZqm/OnTnMQVFdWJEUplEqqogAAeEGIwkhRFAAoQnjQ63z88ScXL7++MD/ne86NO0tzM7O2bXEQyrKiaeqLGryd5/cEu86AF2VVkYMgGHaEJUkCAJQk8fmXp9T/9qyDE/i+omnbFR9HEcfz7XZLkrWJ8fJze4IdZzAYuJOTE9/OIop8P9B1PY6THdORUhIEka5ru5k0NGD4+gEA/V5P1XRZll700fbeEwwACHyfF8Tn3sX3PFGSd3xNaZpQ+nhOzAOyx57gMAwAgHGc2LYFIaSUdjqdbC7Hc1ySJADsag9K02q1Wh6fUGQJAOD7vqZphzWG4sWGQpxYRmcoxAnkezMUYueyLwiCNE1d1wUABL63NXN3GAbDfJYmSfhttNswZOmlW70vCMZp+sg2z30ckocR8n0fYzysCfietz1+MgiC9FifKPB9hHC30wnCKE2TIAyDINh6kC3SNA3
DCACQpkkQRgAAQrDn+wCAJI6jKAYAYIQ6nU6aphgjPwiGJ1JCAt8Pw8OPz2q3Wqtr68NUfQJC8DAqdbslAADPc4cD1Ld+bOeAucv13GeLc+dGsDPou54fhZGma4QgXpCLxUKzUQcUcBzHCULGNhuNpqJqBKFet8fL6msXT0qYUprEa+tVRZY933cdhxPVU/NTjUaztrZWmZnlAFVU1XMGtVrDztqUE0xNrdfW86UJTZFlWYmi0HEcK5fP6Gq31zcti4ew1++bhgk56LnuWGVcP1gd9JnGJ0EY3rl+TbEyG9W1sfKEqgjtbk8RJFHRxopZCrg4CiRFVRUlCIIwTsaKxTDw7iw/WFyYcwcOgZwqS616TVL1SqXsDgYDx8nlckHgxwidml8kOO11O7du3Jqem5s/taBru1b2XhRC8G9+81HPjT58/11KiaqqCGFKsSSrmionUXT1xo3Lr73eaTdaPbdSKmSzOYHnNmsblcqkosj3l5ey+UqaRIZhmqbpDPpBGHqu54fx4qm5OI4o4HAaQ8jHcZTL5xzH1XUDo9RxHF4QBZ4TZSWOAkEQA98XZVmRxDvLD37nww/F3V0ROwugP+j99tPP86YVRKEk8zECP/rg/S+//FIRxBghSZHf+sGbV658AjlR07R+z7l0+fXDSsSDM+j1rnzy2zROKpOTm/XG6dPn1x6u3Lm77Pa7HdfnIe07buA6iqqKdVWSBEiJ6zn1Rsf1A1NTRVFKEDKzHZD69UankMvFcez6QdY2UgwRii+/8fb5XVz7BycM/c8//7LdqItywwv8XrcHASWQg5hKqnL/PhwMnKxtBlGIUyypmqzIVdOam5tt1mtRELTbTV6SIy/QNQWj9Pr166LA67L84MGDOA78KMlmi+3a6tLyisRzX3zxebEyfogCAJRCniM4/ejXHxmm4biuoiiCJAGMIaAAwihOet1/FAVRkJX1h8uSrBJKIUhX1zbPnj2zWa3evrvCQRB4/k//8A8+/+3HzW4PYTpeGb+3dNuy7TgMOUH0nYFhmUkchXEKAbAsK/Q9SZLCKFUNVRT4Qb8ncIIoSZDnRfk5n6qdq0CEkGJ5XBS4FKEkScYnJw1dn52d4SiBHK/rGqU0k8n4rqsbVsbOmsYJGqsYRiFC2LIsXdfPnj1rWxbPC/l83s5k83YGckLGMi3Lnp2Zy9hZVVULhYKm69msXSiVZUXMZLOE4CgIkiTVVDWJY0FWDE0NgzAMQ5S+3GVENE13er1sNlco5G3bLhSKmirPzsxYpjk+MeF7fpwk2VzO0LVhbKGuaq4XqKpSLJU8z5ucmJIVdWpqqlQsSaKICZmdm+EABZAzdI0QQgjmOE4QxIxtUUIPt9XHcXxprFyuTMxMz+bzuWwuS9JY0w3TNB3XDYNwYmKcgyCTy2OE5hZOcxQ5AyeXzQ36XU3TKaWEklOLi0MjS8WCaZnFQiGXzcRxLEmKLAmmaaqKXCwUwsC3LAsCIInS4sLCZKVcKpVVVclmc4ZuLC4sjJcKsqza9nPih3YuAXK5PIEiSCKyXo3ioFQsyopaLpd1WV1dW63XW2cWFyVJXjx9JolCACEAJ6j1rGp6Lp8v5AtxEqaImAZJ06TbH5w/fx5j3Bl0x8ZKumFOz0yHYeT67vrKQ0VRCOXOnp4DOFlf3zh37vyg142SRFWBZei8JPe7HcO0IKCtVlt5cYfP3hFE8fyFC0mS1Fut04unr3/zTWVqampyMmNldNMIfQ9h7Az6upl5//0ZCgDkeMu0dE2LgiBj237gT05MFDL2rVs3x8YnpyUhm8ubqr5WXavW6j/68EOeg3Y2f+a07Pv+e+/PAXKo6zJBUClXLCsSIfz0yie6lSmNjbW6nUuXLk9MjHMc3+20ZuZODXodUeR93zt97jzCNJuxKuO+bVuXLr++WW+pqvLW229RQiamZ0VVwwREUfTOu+84/T60rMgfeJ7XaDTfeue9ZrM5Pz9PECoU8pTgbEJ4ES7fXZJUrTw
+TinhNjdTwHPcs3xHz/ECUUru31uenpuTHnn96NWvv0oINz1e5kSpWMinSdJodiYnK4eZjntj716gdqulaLqxi2fTGfQQAblsdqed9N7SUrEykbHMwzF6G8/1AhGMNjdqE9NTT7zAKAz6rlsujbVbTUzhWKn43Htdv/p1kJB33nrzsDyJz/YCBb73ySefXnjtcj5jNTq9yUr5UG4KAEBpcvfO0vjUVCaT2fGAOAq7/X6l/KwMeSA36PD4Y49qA0flBqWUvqSHPUo36KG/tee6QV9euh2cF4sIe4IT+1Qvie/H8x79U7wq6cYCYhgjzc4lACWEhCHdh4sAAk6SuUPqb98fFNA4jTHdf/MOAqiICndMU1BRHAN8sBoRFKB4eM7NFyElaYKSg1xBERSeO7qZL3YWgH/tauf/+VuwD38fBMqZc8X/+r/hxGPTQMtv/c2tv3HT/c9AxgP+Txf+9NLYpUO0ao9QFKbX/4oOHhzoKnJWuvzfQ610SEbtFUzwz+/9/EbnxkEu8sH4B787+7vwqOZd21kAwc0b8MF9sK9qXBxG+E9/xuVyBzNs/2y4G3f8OwTu3zNLAb3VvXU8Aoi6oH2NQ86BLhI2iFflj1wAIQqvdq9uppsHuci1zrUPpz4U+T0MuzwMdqkCJQlJ0/0JAKTJfupOhwcmGCF0EAEAAFJ8TEPrCKE4BvhAtQgKOEqOYSwTpTTF6QE7Co845XcRAH3URzh0Zm050Z72pg2jNeG2jRAfc6cYpRRjPBQAJRQ+sx/kGRc5bLv2emdAEcDp1u3hd/Y9uXGY+PS7hwHIH9cMnJTQraGT++OIhyTvLgCMAQDyhddQY1M+dZq6brx6ny9VqDNA7TZv21AScastn3tNWJhJ7z9MHtwDcQwA4E7AmGpCCIHEEjNvl99qhrU1v6oJWj/qhN/9suqiWVaLEMAHzgNC6aNMBAE4PgFQAChBgKRAnyCCAUlCnIfDbwvkVWBOA0EFKMDOKiQxkGzASyTsQnOOug8AJcMjKSTgmBQwXMHAlDKmoMq8suauPe2NyCr5gmxHJF1317+z4zhSfpd+AEoJQgBC89LrcT5rXH49qjbEs6cFVU86XeS6cnlMMLTez/9ezOQ4QdQ+/BfBzKT3i1/QOD7k3vV9QAEhhAAylZ8ZoO5s9tRi9nSEEUZBLwkBiTlOTFHso3Az3Hyj/LYTdWbNeQhozd8cpG7VW8eUHOccxgQDknIkEWZ+jLt3kH0OBlVAEYh7UuEcZ06QoBtlLhDnIZe/wIM0jTzRnEh64xTKfPdrGNQApMclAEoowQRQuJBZkHllSpuKcCpxMCI4RYEtZ79pf1nWpy4VFppBryiPFZVCgsNe7KqC9Fnzs5QgSk6AAB6VABDiKJLn5gCggqJSx5dyRc6ME0wlgcftjjo1BURRLldAnIiQE7J2vFGDx10CUEAJJhhgHvIPu/cRAmNmsaxX6gFatMoxTuPE6YTu2dLpu7eWYowooKZuEUouWm803bXVwUO8fbaTo7efIkAQDto8x0FBFswFqbgIeT658zeAkwFBnJpRcA1O/4h6a5xZEQDiJIHPTGO1grxl6q1TeFz5HxBKMMZuNLDlnABBG4VjekUUBIIjAcoap617D0tq0RLtQqFSddZ5TlCpVrbG2vGAB1yEMaFHmn92dnUPBUAwTnsdXlbSbpdSiiU5HfQxJkBT441NLMkEYRwG3toKFUQsSSRNCcbkuNsAAACCCcEkSKLL+TfP5E8pQKIEiETiKM8RKALJFExCqCVkKE4ycrbptZp+24m9MAlShAkmx9kGIBiQFKpFEHWolIEkhgBBimnYBsmABg2aOCmQCU5SoBDAp1TGlEMEkLjPKyYgKSDoeEuAJIlSQpxooEENoaQT9Hgi6qIu84LMKTzA624dkVSEct1ttIMeJUAG0vCtnYgSAHzbBnA++8S79hWAHCcpQj7rrK/xps2JfNJqK7Nz4d07nGrIXrl79ZpgaFGjSTHmyLFVQL+
1nRJCCCXLrSW9LN5cuVlUi/36N6ZsToG5B+2bFAoxioLUj1D4Tw//wZTtIPUJAPP2QifYHAr4OGP8CAI4BWEjvft/cfo4FOTYr0PIk6CdrP0SQg4qJbF8GS3/ktPHAa7wYRc3OxBiyKu4fw/gFHDHVgQQQoYJ+JuVX2OcTGdmumE3QnFRH8MktpVc3WtWBxsJpVnFhoDzE48CsKpkLoxdfCSAk9AG2GoER7VdfbrRRhUAAEArXHv4nXOPuwoEACCYYIo9PPh45WMAwH3wAAAAIbfaW237zSfi7hpuEwAAAZSg2PYbQyfG8QqAkpQGDQAAcTe276FuFQAAgzZxH5KwA9o3sVai8YB+t+eYAng8+Z8+9gI1nBoAoO13hnvqTu2JY7tea+t33dlsB50wiTDBJ6IE2BLAPiAnoRGMyU5VSVIfPPkatrPS2abkYyzDKAbP9OLTxKHJo54y4u34heKOvQR4URrfvpoTUQJAgd+3ACiE4FjXX+Egd/BYJx4e00IskKOUAnKwziBOAPAYFjGBEHKAO2AjkAf8kY2DALvODPfOe87t2yTZT3+k9YO3BPM417Gazc6+N/5eN+zu+woiJ/5g/AeHaNLe4dQCnflj0rp6kItAYwJm5g/LpL2jiuqPpn6kQGXfK+FxkPvh9A+PchmonQNiKKUoivY3dzYnibxwFAM5nhEQE6cHWsobQiiL8sse0b5bQAxBKTnYUAjICbz4ctei2y0gBmGUogMUXxBIgvSyR4M+PyAGQiiqL2vmjyNAfsmv/6XCCSJ3JF+Ql4HACwJ/IpaQ2yMsIIYx0jABMEYaJgDGSPNYAK9KFDODcUC2Z/XHAuB5/hXSABxO1/+twa+W8QAAnue3VqvnOO6ErP++R4ZrF2z93v4iTj7bjQfb3aDgyGMRDgiEcHu6v9LG02Mdf7oPuG19na+08fDVMp3BOFxYI5gx0jABMEYaJgDGSMMEwBhpmAAYI80O45Z833/lFi1kMJ6LLMtPr967gwBkWRaEV2lAH4OxF3bsbWT9AIyRhrUBGCMNEwBjpGECYIw0TACMkYYJgDHSMAEwRhomgEMGI+R5HqU0jiI/CJ6IUkAoDcJw6HrebSEJjHGapoQQjDHCGKGUEEoIRgi9pLWQRxnW4XWYUEq//uqL20v3f/df/sv1B/fur23+wR/8fjZjI4RESYIA3L514+tvbnzwo98pl3LXbt6ulMcrY0VCCMdxGGNBECglK/cfrlXrb7x5afnu7XbPTeNoanYu8j2eAxvN/p/97I/FVyp87ITDBHCYpElc3awVC7l6s3n63LkwJf/4D7/gBSEMwumZuUKhEMWJKIo3r1294rl+HF+/ek3XNAAhAHTQH4yVyyiNeSjGMQ6DMF8ora9X86WyMxjEvv+Dt3/Ay5vCES4hOgqwKtDhQiEAkixDACGEnuvNTE8HcTI7O11dX0kxgQAKovhwZWXgDARJmZyanJgcj6KQ8uLi2bNJErdancHAQWmapiiOwkwuL8uSadrFfK7f7xcKhVcn+PbVgAngMBElZaxYqFY3CoU8z/PZXGZldbWQz1t2Jp/LmobB85zrutNT09lsvpDPRVEk8hxGSFMVy7ZKY2OFQr5UKlqWyfPw3t17kBOqa+uiKCqqSgm1LPO4H/H7BhsLdMikaeJ5gW3bAFCEUBRFsiwDQL/68ssLly7LkuD7ga7rcRxLsozSVBTFIAgkSRZFgQIQR5EoipQCSRKDwKcAchBuTWS5fS4JxqHABHAUUEqDINA07RWaPmREYAJgjDSsPGWMNEwAjJGGCYAx0jABMEYaJgDGSMMEwBhpmAAYIw0TAGOkEVhHGGOUYSUAY6RhAmCMNKwKxBhpWAnAGGlYCcAYaZgAGCMNEwBjpGECYIw0TACMkUZ4tRZYZzAOF1YCMEYaVgIwRhomAMZI8/8DJID4UsCn/V8AAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, 
"execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image, *_ = test_data.__getitem__(np.random.randint(len(test_data)))\n", "t = transforms.ToPILImage()\n", "image = image.unsqueeze(0)\n", "t(image.squeeze())" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "header{\n", "btn-inactive,btn-active\n", "}\n", "row{\n", "single{\n", "small-title,text,btn-green\n", "}\n", "}\n", "row{\n", "double{\n", "small-title,text,btn-green\n", "}\n", "double{\n", "small-title,text,btn-green\n", "}\n", "}\n", "row{\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "}\n", "\n" ] } ], "source": [ "image = image.cuda()\n", "ct = []\n", "ct.append(vocab.to_vec(' '))\n", "ct.append(vocab.to_vec(''))\n", "output = ''\n", "for i in range(200):\n", " context = torch.tensor(ct).unsqueeze(0).float().cuda()\n", " index = torch.argmax(net(image, context), 2).squeeze()[-1:].squeeze()\n", " v = vocab.to_vocab(int(index))\n", " if v == '':\n", " break\n", " output += v\n", " ct.append(vocab.to_vec(v))\n", "\n", "with open('./compiler/output.gui', 'w') as f:\n", " f.write(output)\n", "\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now from the compiler directory in your terminal run\n", "`python web-compiler.py output.gui`.\n", "This will generate a `output.html` file that you can open in your browser." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.2 64-bit ('pytorch': conda)", "language": "python", "name": "python38264bitpytorchcondaf04cb2303bb94659b54446e023c3cb62" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }