{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from util import UIDataset, Vocabulary\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import DataLoader\n", "from model import *\n", "from torchvision import transforms\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Build the training dataset from the ./dataset/training directory and voc.pkl\n", "# (presumably the pickled Vocabulary - confirm in util.py).\n", "dataset = UIDataset('./dataset/training', 'voc.pkl')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Use the GPU when one is available; fall back to the CPU instead of raising on CUDA-less machines.\n", "# NOTE(review): cells that move batches with .cuda() should switch to `device` as well - confirm.\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "LEARNING_RATE = 1e-4  # step size handed to Adam (same value as the former literal 0.0001)\n", "\n", "net = Pix2Code().to(device)\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss: 2.9524800777435303, Epoch: 0\n", "Loss: 2.8884081840515137, Epoch: 0\n", "Loss: 2.829449415206909, Epoch: 0\n", "Loss: 2.740262269973755, Epoch: 0\n", "Loss: 2.621814489364624, Epoch: 0\n", "Loss: 2.5236337184906006, Epoch: 0\n", "Loss: 2.4583466053009033, Epoch: 0\n", "Loss: 2.4608116149902344, Epoch: 0\n", "Loss: 2.4273080825805664, Epoch: 0\n", "Loss: 2.385077714920044, Epoch: 0\n", "Loss: 2.3915295600891113, Epoch: 0\n", "Loss: 2.3580784797668457, Epoch: 0\n", "Loss: 2.353372812271118, Epoch: 0\n", "Loss: 2.3451170921325684, Epoch: 0\n", "Loss: 2.350393772125244, Epoch: 0\n", "Loss: 2.3572347164154053, Epoch: 0\n", "Loss: 2.3479745388031006, Epoch: 0\n", "Loss: 2.3289999961853027, Epoch: 0\n", "Loss: 2.3202321529388428, Epoch: 0\n", "Loss: 2.318603277206421, Epoch: 0\n", "Loss: 2.3135170936584473, Epoch: 0\n", "Loss: 2.2949447631835938, Epoch: 0\n", "Loss: 2.3082993030548096, Epoch: 0\n", "Loss: 2.3046371936798096, Epoch: 0\n", "Loss: 2.306605339050293, Epoch: 0\n", "Loss: 2.303187847137451, 
Epoch: 0\n", "Loss: 2.2787742614746094, Epoch: 0\n", "Loss: 2.300393581390381, Epoch: 0\n", "Loss: 2.2998340129852295, Epoch: 0\n", "Loss: 2.2719686031341553, Epoch: 0\n", "Loss: 2.2761404514312744, Epoch: 0\n", "Loss: 2.3393924236297607, Epoch: 0\n", "Loss: 2.2692413330078125, Epoch: 0\n", "Loss: 2.270902395248413, Epoch: 0\n", "Loss: 2.282170534133911, Epoch: 0\n", "Loss: 2.258744955062866, Epoch: 0\n", "Loss: 2.261420965194702, Epoch: 0\n", "Loss: 2.247162103652954, Epoch: 0\n", "Loss: 2.238236427307129, Epoch: 0\n", "Loss: 2.2235021591186523, Epoch: 0\n", "Loss: 2.2508511543273926, Epoch: 0\n", "Loss: 2.2157509326934814, Epoch: 0\n", "Loss: 2.2379965782165527, Epoch: 0\n", "Loss: 2.231006145477295, Epoch: 0\n", "Loss: 2.2043988704681396, Epoch: 0\n", "Loss: 2.2038233280181885, Epoch: 0\n", "Loss: 2.1575405597686768, Epoch: 0\n", "Loss: 2.1855974197387695, Epoch: 0\n", "Loss: 2.156130075454712, Epoch: 0\n", "Loss: 2.144746780395508, Epoch: 0\n", "Loss: 2.1328115463256836, Epoch: 0\n", "Loss: 2.1185996532440186, Epoch: 0\n", "Loss: 2.1840450763702393, Epoch: 0\n", "Loss: 2.0942633152008057, Epoch: 0\n", "Loss: 2.0785911083221436, Epoch: 0\n", "Loss: 2.0678865909576416, Epoch: 0\n", "Loss: 2.010136604309082, Epoch: 0\n", "Loss: 1.9797004461288452, Epoch: 0\n", "Loss: 1.9312745332717896, Epoch: 0\n", "Loss: 1.900559663772583, Epoch: 0\n", "Loss: 1.8531383275985718, Epoch: 0\n", "Loss: 1.9304338693618774, Epoch: 0\n", "Loss: 1.7415088415145874, Epoch: 0\n", "Loss: 1.6876497268676758, Epoch: 0\n", "Loss: 1.7912997007369995, Epoch: 0\n", "Loss: 1.5977957248687744, Epoch: 0\n", "Loss: 1.4403373003005981, Epoch: 0\n", "Loss: 1.4909701347351074, Epoch: 0\n", "Loss: 1.4378795623779297, Epoch: 0\n", "Loss: 1.3532891273498535, Epoch: 0\n", "Loss: 1.5037901401519775, Epoch: 0\n", "Loss: 1.1528806686401367, Epoch: 0\n", "Loss: 1.1900337934494019, Epoch: 0\n", "Loss: 1.2141656875610352, Epoch: 0\n", "Loss: 1.0680357217788696, Epoch: 0\n", "Loss: 1.336559534072876, Epoch: 0\n", 
"Loss: 1.1178141832351685, Epoch: 0\n", "Loss: 1.019561529159546, Epoch: 0\n", "Loss: 0.9516667723655701, Epoch: 0\n", "Loss: 0.9714117050170898, Epoch: 0\n", "Loss: 0.8897005319595337, Epoch: 0\n", "Loss: 1.470299482345581, Epoch: 0\n", "Loss: 0.8343897461891174, Epoch: 0\n", "Loss: 1.0053107738494873, Epoch: 0\n", "Loss: 0.8614187240600586, Epoch: 0\n", "Loss: 0.8838141560554504, Epoch: 0\n", "Loss: 0.7585976719856262, Epoch: 0\n", "Loss: 0.736517071723938, Epoch: 0\n", "Loss: 0.741409420967102, Epoch: 0\n", "Loss: 0.8206121325492859, Epoch: 0\n", "Loss: 0.7003164291381836, Epoch: 0\n", "Loss: 0.6972914934158325, Epoch: 0\n", "Loss: 0.653584361076355, Epoch: 0\n", "Loss: 0.6935228705406189, Epoch: 0\n", "Loss: 0.9742002487182617, Epoch: 0\n", "Loss: 0.6670038104057312, Epoch: 0\n", "Loss: 0.6938544511795044, Epoch: 0\n", "Loss: 0.6238456964492798, Epoch: 0\n", "Loss: 0.5920361280441284, Epoch: 0\n", "Loss: 0.6854602694511414, Epoch: 0\n", "Loss: 0.6296735405921936, Epoch: 0\n", "Loss: 0.6589140295982361, Epoch: 0\n", "Loss: 0.5916438698768616, Epoch: 0\n", "Loss: 0.870532214641571, Epoch: 0\n", "Loss: 0.5000392198562622, Epoch: 0\n", "Loss: 0.5296663641929626, Epoch: 0\n", "Loss: 0.731133222579956, Epoch: 0\n", "Loss: 0.5028390884399414, Epoch: 0\n", "Loss: 0.5200638175010681, Epoch: 0\n", "Loss: 0.4418269097805023, Epoch: 0\n", "Loss: 0.4933643937110901, Epoch: 0\n", "Loss: 0.5051903128623962, Epoch: 0\n", "Loss: 0.5060564279556274, Epoch: 0\n", "Loss: 0.46932217478752136, Epoch: 0\n", "Loss: 0.6307882070541382, Epoch: 0\n", "Loss: 0.4105120301246643, Epoch: 0\n", "Loss: 0.45868533849716187, Epoch: 0\n", "Loss: 0.4584932029247284, Epoch: 0\n", "Loss: 0.650614857673645, Epoch: 0\n", "Loss: 0.4539167582988739, Epoch: 0\n", "Loss: 0.4140841066837311, Epoch: 0\n", "Loss: 0.4211380183696747, Epoch: 0\n", "Loss: 0.4530402719974518, Epoch: 0\n", "Loss: 0.40992099046707153, Epoch: 0\n", "Loss: 0.45029035210609436, Epoch: 0\n", "Loss: 0.39612260460853577, Epoch: 0\n", 
"Loss: 0.43665429949760437, Epoch: 0\n", "Loss: 0.3842218220233917, Epoch: 0\n", "Loss: 0.38374099135398865, Epoch: 0\n", "Loss: 0.4136958122253418, Epoch: 0\n", "Loss: 0.44654661417007446, Epoch: 0\n", "Loss: 0.3930183947086334, Epoch: 0\n", "Loss: 0.40135490894317627, Epoch: 0\n", "Loss: 0.34777334332466125, Epoch: 0\n", "Loss: 0.4132601320743561, Epoch: 0\n", "Loss: 0.4124941825866699, Epoch: 0\n", "Loss: 0.42920032143592834, Epoch: 0\n", "Loss: 0.37013471126556396, Epoch: 0\n", "Loss: 0.3565217852592468, Epoch: 0\n", "Loss: 0.35290029644966125, Epoch: 0\n", "Loss: 0.3940131366252899, Epoch: 0\n", "Loss: 0.3507075309753418, Epoch: 0\n", "Loss: 0.340397447347641, Epoch: 0\n", "Loss: 0.34468895196914673, Epoch: 0\n", "Loss: 0.35861828923225403, Epoch: 0\n", "Loss: 0.331664502620697, Epoch: 0\n", "Loss: 0.3724385201931, Epoch: 0\n", "Loss: 0.4620945453643799, Epoch: 0\n", "Loss: 0.40686291456222534, Epoch: 0\n", "Loss: 0.34405651688575745, Epoch: 0\n", "Loss: 0.33751004934310913, Epoch: 0\n", "Loss: 0.33067846298217773, Epoch: 0\n", "Loss: 0.35637813806533813, Epoch: 0\n", "Loss: 0.33678290247917175, Epoch: 0\n", "Loss: 0.3399815857410431, Epoch: 0\n", "Loss: 0.3555727005004883, Epoch: 0\n", "Loss: 0.36221396923065186, Epoch: 0\n", "Loss: 0.31564000248908997, Epoch: 0\n", "Loss: 0.36513498425483704, Epoch: 1\n", "Loss: 0.38518428802490234, Epoch: 1\n", "Loss: 0.39113444089889526, Epoch: 1\n", "Loss: 0.31585872173309326, Epoch: 1\n", "Loss: 0.3462843894958496, Epoch: 1\n", "Loss: 0.31152087450027466, Epoch: 1\n", "Loss: 0.49244046211242676, Epoch: 1\n", "Loss: 0.3516445755958557, Epoch: 1\n", "Loss: 0.31429073214530945, Epoch: 1\n", "Loss: 0.3202592730522156, Epoch: 1\n", "Loss: 0.3343992531299591, Epoch: 1\n", "Loss: 0.3129790127277374, Epoch: 1\n", "Loss: 0.3035936653614044, Epoch: 1\n", "Loss: 0.3117291331291199, Epoch: 1\n", "Loss: 0.34716925024986267, Epoch: 1\n", "Loss: 0.2994453012943268, Epoch: 1\n", "Loss: 0.3337497413158417, Epoch: 1\n", "Loss: 
0.29858383536338806, Epoch: 1\n", "Loss: 0.3567463457584381, Epoch: 1\n", "Loss: 0.2959181070327759, Epoch: 1\n", "Loss: 0.2913894057273865, Epoch: 1\n", "Loss: 0.31215083599090576, Epoch: 1\n", "Loss: 0.28617092967033386, Epoch: 1\n", "Loss: 0.31746673583984375, Epoch: 1\n", "Loss: 0.2842235565185547, Epoch: 1\n", "Loss: 0.3054017126560211, Epoch: 1\n", "Loss: 0.2990272045135498, Epoch: 1\n", "Loss: 0.2779654264450073, Epoch: 1\n", "Loss: 0.30389276146888733, Epoch: 1\n", "Loss: 0.3101218044757843, Epoch: 1\n", "Loss: 0.28804558515548706, Epoch: 1\n", "Loss: 0.34255534410476685, Epoch: 1\n", "Loss: 0.27916717529296875, Epoch: 1\n", "Loss: 0.2679106891155243, Epoch: 1\n", "Loss: 0.32168295979499817, Epoch: 1\n", "Loss: 0.2777051627635956, Epoch: 1\n", "Loss: 0.2755148708820343, Epoch: 1\n", "Loss: 0.2674615681171417, Epoch: 1\n", "Loss: 0.2992023825645447, Epoch: 1\n", "Loss: 0.2748924493789673, Epoch: 1\n", "Loss: 0.29646700620651245, Epoch: 1\n", "Loss: 0.28040811419487, Epoch: 1\n", "Loss: 0.30196037888526917, Epoch: 1\n", "Loss: 0.2965075969696045, Epoch: 1\n", "Loss: 0.26315343379974365, Epoch: 1\n", "Loss: 0.251770555973053, Epoch: 1\n", "Loss: 0.26690012216567993, Epoch: 1\n", "Loss: 0.2505774199962616, Epoch: 1\n", "Loss: 0.28790804743766785, Epoch: 1\n", "Loss: 0.2832382023334503, Epoch: 1\n", "Loss: 0.2817041575908661, Epoch: 1\n", "Loss: 0.28491726517677307, Epoch: 1\n", "Loss: 0.301067054271698, Epoch: 1\n", "Loss: 0.263874888420105, Epoch: 1\n", "Loss: 0.24870161712169647, Epoch: 1\n", "Loss: 0.2826177775859833, Epoch: 1\n", "Loss: 0.2598406672477722, Epoch: 1\n", "Loss: 0.25691503286361694, Epoch: 1\n", "Loss: 0.2716769576072693, Epoch: 1\n", "Loss: 0.2626866400241852, Epoch: 1\n", "Loss: 0.28078770637512207, Epoch: 1\n", "Loss: 0.36835792660713196, Epoch: 1\n", "Loss: 0.2390429675579071, Epoch: 1\n", "Loss: 0.2525367736816406, Epoch: 1\n", "Loss: 0.2904892861843109, Epoch: 1\n", "Loss: 0.2849578261375427, Epoch: 1\n", "Loss: 0.2514594793319702, 
Epoch: 1\n", "Loss: 0.2684769630432129, Epoch: 1\n", "Loss: 0.24913877248764038, Epoch: 1\n", "Loss: 0.26298603415489197, Epoch: 1\n", "Loss: 0.2606211006641388, Epoch: 1\n", "Loss: 0.24511288106441498, Epoch: 1\n", "Loss: 0.25391775369644165, Epoch: 1\n", "Loss: 0.28190696239471436, Epoch: 1\n", "Loss: 0.2384095937013626, Epoch: 1\n", "Loss: 0.25621259212493896, Epoch: 1\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.23397196829319, Epoch: 1\n", "Loss: 0.2452157586812973, Epoch: 1\n", "Loss: 0.25013604760169983, Epoch: 1\n", "Loss: 0.24252459406852722, Epoch: 1\n", "Loss: 0.24472461640834808, Epoch: 1\n", "Loss: 0.3319304883480072, Epoch: 1\n", "Loss: 0.23765243589878082, Epoch: 1\n", "Loss: 0.24847112596035004, Epoch: 1\n", "Loss: 0.2331979125738144, Epoch: 1\n", "Loss: 0.26821067929267883, Epoch: 1\n", "Loss: 0.23324160277843475, Epoch: 1\n", "Loss: 0.2333941012620926, Epoch: 1\n", "Loss: 0.23415108025074005, Epoch: 1\n", "Loss: 0.2484195977449417, Epoch: 1\n", "Loss: 0.23526932299137115, Epoch: 1\n", "Loss: 0.22446300089359283, Epoch: 1\n", "Loss: 0.22437800467014313, Epoch: 1\n", "Loss: 0.2274080067873001, Epoch: 1\n", "Loss: 0.23755177855491638, Epoch: 1\n", "Loss: 0.21210730075836182, Epoch: 1\n", "Loss: 0.2223813533782959, Epoch: 1\n", "Loss: 0.22736792266368866, Epoch: 1\n", "Loss: 0.22397784888744354, Epoch: 1\n", "Loss: 0.2413712739944458, Epoch: 1\n", "Loss: 0.2235811948776245, Epoch: 1\n", "Loss: 0.24524028599262238, Epoch: 1\n", "Loss: 0.22623442113399506, Epoch: 1\n", "Loss: 0.24114780128002167, Epoch: 1\n", "Loss: 0.22427205741405487, Epoch: 1\n", "Loss: 0.2151176482439041, Epoch: 1\n", "Loss: 0.21807488799095154, Epoch: 1\n", "Loss: 0.2142992466688156, Epoch: 1\n", "Loss: 0.21087424457073212, Epoch: 1\n", "Loss: 0.2140428125858307, Epoch: 1\n", "Loss: 0.21240393817424774, Epoch: 1\n", "Loss: 0.20539182424545288, Epoch: 1\n", "Loss: 0.20411120355129242, Epoch: 1\n", "Loss: 0.21139389276504517, Epoch: 1\n", "Loss: 
0.24850775301456451, Epoch: 1\n", "Loss: 0.21610818803310394, Epoch: 1\n", "Loss: 0.20004600286483765, Epoch: 1\n", "Loss: 0.19582688808441162, Epoch: 1\n", "Loss: 0.23555868864059448, Epoch: 1\n", "Loss: 0.20441941916942596, Epoch: 1\n", "Loss: 0.20177888870239258, Epoch: 1\n", "Loss: 0.19751212000846863, Epoch: 1\n", "Loss: 0.20864750444889069, Epoch: 1\n", "Loss: 0.20159509778022766, Epoch: 1\n", "Loss: 0.2005327045917511, Epoch: 1\n", "Loss: 0.1999230831861496, Epoch: 1\n", "Loss: 0.2040192186832428, Epoch: 1\n", "Loss: 0.19645045697689056, Epoch: 1\n", "Loss: 0.199203759431839, Epoch: 1\n", "Loss: 0.20018835365772247, Epoch: 1\n", "Loss: 0.19527064263820648, Epoch: 1\n", "Loss: 0.19556036591529846, Epoch: 1\n", "Loss: 0.19404765963554382, Epoch: 1\n", "Loss: 0.18973395228385925, Epoch: 1\n", "Loss: 0.19332443177700043, Epoch: 1\n", "Loss: 0.1943163424730301, Epoch: 1\n", "Loss: 0.22645233571529388, Epoch: 1\n", "Loss: 0.18985994160175323, Epoch: 1\n", "Loss: 0.18457040190696716, Epoch: 1\n", "Loss: 0.1887940615415573, Epoch: 1\n", "Loss: 0.184039905667305, Epoch: 1\n", "Loss: 0.18834181129932404, Epoch: 1\n", "Loss: 0.19446147978305817, Epoch: 1\n", "Loss: 0.1980251669883728, Epoch: 1\n", "Loss: 0.17985910177230835, Epoch: 1\n", "Loss: 0.1837856024503708, Epoch: 1\n", "Loss: 0.187068909406662, Epoch: 1\n", "Loss: 0.21451883018016815, Epoch: 1\n", "Loss: 0.18576252460479736, Epoch: 1\n", "Loss: 0.1763405203819275, Epoch: 1\n", "Loss: 0.1780242919921875, Epoch: 1\n", "Loss: 0.17861078679561615, Epoch: 1\n", "Loss: 0.1907288134098053, Epoch: 1\n", "Loss: 0.1792900711297989, Epoch: 1\n", "Loss: 0.18664227426052094, Epoch: 1\n", "Loss: 0.17793485522270203, Epoch: 1\n", "Loss: 0.17487949132919312, Epoch: 1\n", "Loss: 0.20927435159683228, Epoch: 1\n", "Loss: 0.17920136451721191, Epoch: 2\n", "Loss: 0.17775531113147736, Epoch: 2\n", "Loss: 0.17145593464374542, Epoch: 2\n", "Loss: 0.17465360462665558, Epoch: 2\n", "Loss: 0.18824218213558197, Epoch: 2\n", "Loss: 
0.17750504612922668, Epoch: 2\n", "Loss: 0.19186504185199738, Epoch: 2\n", "Loss: 0.16394080221652985, Epoch: 2\n", "Loss: 0.1819714903831482, Epoch: 2\n", "Loss: 0.16088007390499115, Epoch: 2\n", "Loss: 0.18525062501430511, Epoch: 2\n", "Loss: 0.17226216197013855, Epoch: 2\n", "Loss: 0.17414595186710358, Epoch: 2\n", "Loss: 0.18195168673992157, Epoch: 2\n", "Loss: 0.17499519884586334, Epoch: 2\n", "Loss: 0.17727644741535187, Epoch: 2\n", "Loss: 0.17118574678897858, Epoch: 2\n", "Loss: 0.17873334884643555, Epoch: 2\n", "Loss: 0.1621459275484085, Epoch: 2\n", "Loss: 0.17832708358764648, Epoch: 2\n", "Loss: 0.16719700396060944, Epoch: 2\n", "Loss: 0.17284554243087769, Epoch: 2\n", "Loss: 0.16878937184810638, Epoch: 2\n", "Loss: 0.16516879200935364, Epoch: 2\n", "Loss: 0.16708654165267944, Epoch: 2\n", "Loss: 0.1831068992614746, Epoch: 2\n", "Loss: 0.18170510232448578, Epoch: 2\n", "Loss: 0.1666964292526245, Epoch: 2\n", "Loss: 0.18964578211307526, Epoch: 2\n", "Loss: 0.16580110788345337, Epoch: 2\n", "Loss: 0.1696055829524994, Epoch: 2\n", "Loss: 0.18262037634849548, Epoch: 2\n", "Loss: 0.16144578158855438, Epoch: 2\n", "Loss: 0.16571661829948425, Epoch: 2\n", "Loss: 0.16096633672714233, Epoch: 2\n", "Loss: 0.16165024042129517, Epoch: 2\n", "Loss: 0.16471704840660095, Epoch: 2\n", "Loss: 0.16848206520080566, Epoch: 2\n", "Loss: 0.17229224741458893, Epoch: 2\n", "Loss: 0.1488664597272873, Epoch: 2\n", "Loss: 0.16164709627628326, Epoch: 2\n", "Loss: 0.1660623997449875, Epoch: 2\n", "Loss: 0.16126243770122528, Epoch: 2\n", "Loss: 0.16471469402313232, Epoch: 2\n", "Loss: 0.15973404049873352, Epoch: 2\n", "Loss: 0.15950854122638702, Epoch: 2\n", "Loss: 0.14817242324352264, Epoch: 2\n", "Loss: 0.15562114119529724, Epoch: 2\n", "Loss: 0.15278998017311096, Epoch: 2\n", "Loss: 0.15652954578399658, Epoch: 2\n", "Loss: 0.1548881232738495, Epoch: 2\n", "Loss: 0.1577402502298355, Epoch: 2\n", "Loss: 0.1587544083595276, Epoch: 2\n", "Loss: 0.16029223799705505, Epoch: 2\n", "Loss: 
0.15766172111034393, Epoch: 2\n", "Loss: 0.1566898226737976, Epoch: 2\n", "Loss: 0.15709587931632996, Epoch: 2\n", "Loss: 0.15407609939575195, Epoch: 2\n", "Loss: 0.15885479748249054, Epoch: 2\n", "Loss: 0.16054479777812958, Epoch: 2\n", "Loss: 0.16161851584911346, Epoch: 2\n", "Loss: 0.19581153988838196, Epoch: 2\n", "Loss: 0.1506919115781784, Epoch: 2\n", "Loss: 0.15983256697654724, Epoch: 2\n", "Loss: 0.1840815395116806, Epoch: 2\n", "Loss: 0.15997251868247986, Epoch: 2\n", "Loss: 0.158875972032547, Epoch: 2\n", "Loss: 0.1611945927143097, Epoch: 2\n", "Loss: 0.15246137976646423, Epoch: 2\n", "Loss: 0.1503104567527771, Epoch: 2\n", "Loss: 0.1955866515636444, Epoch: 2\n", "Loss: 0.14065603911876678, Epoch: 2\n", "Loss: 0.16270065307617188, Epoch: 2\n", "Loss: 0.1538032591342926, Epoch: 2\n", "Loss: 0.15558446943759918, Epoch: 2\n", "Loss: 0.177334263920784, Epoch: 2\n", "Loss: 0.15660960972309113, Epoch: 2\n", "Loss: 0.14623355865478516, Epoch: 2\n", "Loss: 0.1487160623073578, Epoch: 2\n", "Loss: 0.14769864082336426, Epoch: 2\n", "Loss: 0.15187585353851318, Epoch: 2\n", "Loss: 0.18615028262138367, Epoch: 2\n", "Loss: 0.15663014352321625, Epoch: 2\n", "Loss: 0.1567113846540451, Epoch: 2\n", "Loss: 0.1475883573293686, Epoch: 2\n", "Loss: 0.14988942444324493, Epoch: 2\n", "Loss: 0.1529097557067871, Epoch: 2\n", "Loss: 0.14954881370067596, Epoch: 2\n", "Loss: 0.14585697650909424, Epoch: 2\n", "Loss: 0.1537933349609375, Epoch: 2\n", "Loss: 0.16131213307380676, Epoch: 2\n", "Loss: 0.14611506462097168, Epoch: 2\n", "Loss: 0.15565554797649384, Epoch: 2\n", "Loss: 0.14928916096687317, Epoch: 2\n", "Loss: 0.18597358465194702, Epoch: 2\n", "Loss: 0.1498081088066101, Epoch: 2\n", "Loss: 0.1543579250574112, Epoch: 2\n", "Loss: 0.1561511605978012, Epoch: 2\n", "Loss: 0.15706948935985565, Epoch: 2\n", "Loss: 0.16255070269107819, Epoch: 2\n", "Loss: 0.15505479276180267, Epoch: 2\n", "Loss: 0.15793101489543915, Epoch: 2\n", "Loss: 0.15751025080680847, Epoch: 2\n", "Loss: 
0.17496031522750854, Epoch: 2\n", "Loss: 0.15238243341445923, Epoch: 2\n", "Loss: 0.1480635404586792, Epoch: 2\n", "Loss: 0.16778671741485596, Epoch: 2\n", "Loss: 0.1499747931957245, Epoch: 2\n", "Loss: 0.14654149115085602, Epoch: 2\n", "Loss: 0.15334898233413696, Epoch: 2\n", "Loss: 0.14312013983726501, Epoch: 2\n", "Loss: 0.14889495074748993, Epoch: 2\n", "Loss: 0.15227057039737701, Epoch: 2\n", "Loss: 0.15047228336334229, Epoch: 2\n", "Loss: 0.1697094738483429, Epoch: 2\n", "Loss: 0.14746415615081787, Epoch: 2\n", "Loss: 0.14284475147724152, Epoch: 2\n", "Loss: 0.14408795535564423, Epoch: 2\n", "Loss: 0.1655958741903305, Epoch: 2\n", "Loss: 0.15247742831707, Epoch: 2\n", "Loss: 0.15246184170246124, Epoch: 2\n", "Loss: 0.1515989601612091, Epoch: 2\n", "Loss: 0.14632681012153625, Epoch: 2\n", "Loss: 0.15054377913475037, Epoch: 2\n", "Loss: 0.15041185915470123, Epoch: 2\n", "Loss: 0.15458422899246216, Epoch: 2\n", "Loss: 0.14498606324195862, Epoch: 2\n", "Loss: 0.1463392674922943, Epoch: 2\n", "Loss: 0.14906661212444305, Epoch: 2\n", "Loss: 0.15188321471214294, Epoch: 2\n", "Loss: 0.14843648672103882, Epoch: 2\n", "Loss: 0.1495736539363861, Epoch: 2\n", "Loss: 0.1508703976869583, Epoch: 2\n", "Loss: 0.1415024846792221, Epoch: 2\n", "Loss: 0.14888466894626617, Epoch: 2\n", "Loss: 0.14863824844360352, Epoch: 2\n", "Loss: 0.1804545372724533, Epoch: 2\n", "Loss: 0.14639806747436523, Epoch: 2\n", "Loss: 0.14789406955242157, Epoch: 2\n", "Loss: 0.1517217457294464, Epoch: 2\n", "Loss: 0.15233184397220612, Epoch: 2\n", "Loss: 0.14604727923870087, Epoch: 2\n", "Loss: 0.14814278483390808, Epoch: 2\n", "Loss: 0.14410676062107086, Epoch: 2\n", "Loss: 0.14756864309310913, Epoch: 2\n", "Loss: 0.15017764270305634, Epoch: 2\n", "Loss: 0.15275132656097412, Epoch: 2\n", "Loss: 0.15587218105793, Epoch: 2\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.14909355342388153, Epoch: 2\n", "Loss: 0.14263498783111572, Epoch: 2\n", "Loss: 0.14612025022506714, Epoch: 
2\n", "Loss: 0.147268146276474, Epoch: 2\n", "Loss: 0.14908260107040405, Epoch: 2\n", "Loss: 0.13998505473136902, Epoch: 2\n", "Loss: 0.14733323454856873, Epoch: 2\n", "Loss: 0.14932139217853546, Epoch: 2\n", "Loss: 0.14685842394828796, Epoch: 2\n", "Loss: 0.14514578878879547, Epoch: 2\n", "Loss: 0.14932207763195038, Epoch: 3\n", "Loss: 0.14646084606647491, Epoch: 3\n", "Loss: 0.14726775884628296, Epoch: 3\n", "Loss: 0.1431538462638855, Epoch: 3\n", "Loss: 0.15928220748901367, Epoch: 3\n", "Loss: 0.14593127369880676, Epoch: 3\n", "Loss: 0.16137444972991943, Epoch: 3\n", "Loss: 0.14118261635303497, Epoch: 3\n", "Loss: 0.1484784036874771, Epoch: 3\n", "Loss: 0.13823699951171875, Epoch: 3\n", "Loss: 0.1517200618982315, Epoch: 3\n", "Loss: 0.14846888184547424, Epoch: 3\n", "Loss: 0.14220844209194183, Epoch: 3\n", "Loss: 0.15510502457618713, Epoch: 3\n", "Loss: 0.14475488662719727, Epoch: 3\n", "Loss: 0.14743365347385406, Epoch: 3\n", "Loss: 0.1498083770275116, Epoch: 3\n", "Loss: 0.15431958436965942, Epoch: 3\n", "Loss: 0.1425355225801468, Epoch: 3\n", "Loss: 0.14830829203128815, Epoch: 3\n", "Loss: 0.14888258278369904, Epoch: 3\n", "Loss: 0.14516624808311462, Epoch: 3\n", "Loss: 0.14695410430431366, Epoch: 3\n", "Loss: 0.1333359032869339, Epoch: 3\n", "Loss: 0.14476899802684784, Epoch: 3\n", "Loss: 0.13775229454040527, Epoch: 3\n", "Loss: 0.14642129838466644, Epoch: 3\n", "Loss: 0.14050345122814178, Epoch: 3\n", "Loss: 0.14062468707561493, Epoch: 3\n", "Loss: 0.14814360439777374, Epoch: 3\n", "Loss: 0.14132237434387207, Epoch: 3\n", "Loss: 0.16100995242595673, Epoch: 3\n", "Loss: 0.14359651505947113, Epoch: 3\n", "Loss: 0.14262162148952484, Epoch: 3\n", "Loss: 0.13580752909183502, Epoch: 3\n", "Loss: 0.14163215458393097, Epoch: 3\n", "Loss: 0.15252500772476196, Epoch: 3\n", "Loss: 0.1420316845178604, Epoch: 3\n", "Loss: 0.1503005027770996, Epoch: 3\n", "Loss: 0.13034197688102722, Epoch: 3\n", "Loss: 0.14875632524490356, Epoch: 3\n", "Loss: 0.1441657543182373, Epoch: 
3\n", "Loss: 0.14227809011936188, Epoch: 3\n", "Loss: 0.14413177967071533, Epoch: 3\n", "Loss: 0.14078521728515625, Epoch: 3\n", "Loss: 0.1412949413061142, Epoch: 3\n", "Loss: 0.14173543453216553, Epoch: 3\n", "Loss: 0.13746042549610138, Epoch: 3\n", "Loss: 0.1389954835176468, Epoch: 3\n", "Loss: 0.21365982294082642, Epoch: 3\n", "Loss: 0.13948048651218414, Epoch: 3\n", "Loss: 0.2131626307964325, Epoch: 3\n", "Loss: 0.31315305829048157, Epoch: 3\n", "Loss: 0.18220385909080505, Epoch: 3\n", "Loss: 0.44851815700531006, Epoch: 3\n", "Loss: 0.357102632522583, Epoch: 3\n", "Loss: 0.3697846233844757, Epoch: 3\n", "Loss: 0.2790517807006836, Epoch: 3\n", "Loss: 0.2529226243495941, Epoch: 3\n", "Loss: 0.23575733602046967, Epoch: 3\n", "Loss: 0.19847509264945984, Epoch: 3\n", "Loss: 0.1706792116165161, Epoch: 3\n", "Loss: 0.2386464625597, Epoch: 3\n", "Loss: 0.19907152652740479, Epoch: 3\n", "Loss: 0.21236853301525116, Epoch: 3\n", "Loss: 0.2005978673696518, Epoch: 3\n", "Loss: 0.2367972433567047, Epoch: 3\n", "Loss: 0.17748074233531952, Epoch: 3\n", "Loss: 0.1794709414243698, Epoch: 3\n", "Loss: 0.19476911425590515, Epoch: 3\n", "Loss: 0.23755262792110443, Epoch: 3\n", "Loss: 0.18582652509212494, Epoch: 3\n", "Loss: 0.1880793571472168, Epoch: 3\n", "Loss: 0.21453958749771118, Epoch: 3\n", "Loss: 0.17659011483192444, Epoch: 3\n", "Loss: 0.18438029289245605, Epoch: 3\n", "Loss: 0.16922736167907715, Epoch: 3\n", "Loss: 0.16512253880500793, Epoch: 3\n", "Loss: 0.16392676532268524, Epoch: 3\n", "Loss: 0.16824956238269806, Epoch: 3\n", "Loss: 0.17162224650382996, Epoch: 3\n", "Loss: 0.18260185420513153, Epoch: 3\n", "Loss: 0.1664547324180603, Epoch: 3\n", "Loss: 0.16220341622829437, Epoch: 3\n", "Loss: 0.16524864733219147, Epoch: 3\n", "Loss: 0.19387678802013397, Epoch: 3\n", "Loss: 0.15860770642757416, Epoch: 3\n", "Loss: 0.15186947584152222, Epoch: 3\n", "Loss: 0.1443750560283661, Epoch: 3\n", "Loss: 0.1585564911365509, Epoch: 3\n", "Loss: 0.17480601370334625, Epoch: 3\n", 
"Loss: 0.16876228153705597, Epoch: 3\n", "Loss: 0.15675309300422668, Epoch: 3\n", "Loss: 0.14876218140125275, Epoch: 3\n", "Loss: 0.19663038849830627, Epoch: 3\n", "Loss: 0.14992070198059082, Epoch: 3\n", "Loss: 0.166384756565094, Epoch: 3\n", "Loss: 0.15551204979419708, Epoch: 3\n", "Loss: 0.16162584722042084, Epoch: 3\n", "Loss: 0.16103392839431763, Epoch: 3\n", "Loss: 0.15038976073265076, Epoch: 3\n", "Loss: 0.1739073544740677, Epoch: 3\n", "Loss: 0.1523425579071045, Epoch: 3\n", "Loss: 0.1695767343044281, Epoch: 3\n", "Loss: 0.15175601840019226, Epoch: 3\n", "Loss: 0.14892224967479706, Epoch: 3\n", "Loss: 0.16292822360992432, Epoch: 3\n", "Loss: 0.14466966688632965, Epoch: 3\n", "Loss: 0.14638815820217133, Epoch: 3\n", "Loss: 0.14238741993904114, Epoch: 3\n", "Loss: 0.1467553973197937, Epoch: 3\n", "Loss: 0.1429072618484497, Epoch: 3\n", "Loss: 0.14860624074935913, Epoch: 3\n", "Loss: 0.14519362151622772, Epoch: 3\n", "Loss: 0.17172355949878693, Epoch: 3\n", "Loss: 0.14280566573143005, Epoch: 3\n", "Loss: 0.14112864434719086, Epoch: 3\n", "Loss: 0.14171293377876282, Epoch: 3\n", "Loss: 0.164906844496727, Epoch: 3\n", "Loss: 0.15102867782115936, Epoch: 3\n", "Loss: 0.14635801315307617, Epoch: 3\n", "Loss: 0.14682762324810028, Epoch: 3\n", "Loss: 0.14403878152370453, Epoch: 3\n", "Loss: 0.14384980499744415, Epoch: 3\n", "Loss: 0.15273280441761017, Epoch: 3\n", "Loss: 0.15245644748210907, Epoch: 3\n", "Loss: 0.14466692507266998, Epoch: 3\n", "Loss: 0.14336150884628296, Epoch: 3\n", "Loss: 0.14216336607933044, Epoch: 3\n", "Loss: 0.14766675233840942, Epoch: 3\n", "Loss: 0.1462327241897583, Epoch: 3\n", "Loss: 0.14618737995624542, Epoch: 3\n", "Loss: 0.14568249881267548, Epoch: 3\n", "Loss: 0.1387554109096527, Epoch: 3\n", "Loss: 0.13994747400283813, Epoch: 3\n", "Loss: 0.13912895321846008, Epoch: 3\n", "Loss: 0.1879383772611618, Epoch: 3\n", "Loss: 0.1433798223733902, Epoch: 3\n", "Loss: 0.14359787106513977, Epoch: 3\n", "Loss: 0.14942362904548645, Epoch: 3\n", 
"Loss: 0.14841613173484802, Epoch: 3\n", "Loss: 0.14370928704738617, Epoch: 3\n", "Loss: 0.1471216380596161, Epoch: 3\n", "Loss: 0.14228633046150208, Epoch: 3\n", "Loss: 0.14556394517421722, Epoch: 3\n", "Loss: 0.1456202119588852, Epoch: 3\n", "Loss: 0.1465320587158203, Epoch: 3\n", "Loss: 0.15671274065971375, Epoch: 3\n", "Loss: 0.14550618827342987, Epoch: 3\n", "Loss: 0.13757435977458954, Epoch: 3\n", "Loss: 0.14069128036499023, Epoch: 3\n", "Loss: 0.14401018619537354, Epoch: 3\n", "Loss: 0.14526163041591644, Epoch: 3\n", "Loss: 0.1361975520849228, Epoch: 3\n", "Loss: 0.1407855749130249, Epoch: 3\n", "Loss: 0.1457691341638565, Epoch: 3\n", "Loss: 0.145160973072052, Epoch: 3\n", "Loss: 0.1382274031639099, Epoch: 3\n", "Loss: 0.14645111560821533, Epoch: 4\n", "Loss: 0.14706067740917206, Epoch: 4\n", "Loss: 0.14455042779445648, Epoch: 4\n", "Loss: 0.13920459151268005, Epoch: 4\n", "Loss: 0.15357601642608643, Epoch: 4\n", "Loss: 0.14392921328544617, Epoch: 4\n", "Loss: 0.1616637110710144, Epoch: 4\n", "Loss: 0.13883942365646362, Epoch: 4\n", "Loss: 0.13806527853012085, Epoch: 4\n", "Loss: 0.13434778153896332, Epoch: 4\n", "Loss: 0.14757214486598969, Epoch: 4\n", "Loss: 0.14733949303627014, Epoch: 4\n", "Loss: 0.1386975646018982, Epoch: 4\n", "Loss: 0.15403154492378235, Epoch: 4\n", "Loss: 0.1409301906824112, Epoch: 4\n", "Loss: 0.143764927983284, Epoch: 4\n", "Loss: 0.142208531498909, Epoch: 4\n", "Loss: 0.14942139387130737, Epoch: 4\n", "Loss: 0.14091135561466217, Epoch: 4\n", "Loss: 0.14473268389701843, Epoch: 4\n", "Loss: 0.14563648402690887, Epoch: 4\n", "Loss: 0.13751356303691864, Epoch: 4\n", "Loss: 0.13772264122962952, Epoch: 4\n", "Loss: 0.13282646238803864, Epoch: 4\n", "Loss: 0.14611919224262238, Epoch: 4\n", "Loss: 0.13140395283699036, Epoch: 4\n", "Loss: 0.14141005277633667, Epoch: 4\n", "Loss: 0.13362324237823486, Epoch: 4\n", "Loss: 0.13560856878757477, Epoch: 4\n", "Loss: 0.1455288678407669, Epoch: 4\n", "Loss: 0.13728894293308258, Epoch: 4\n", "Loss: 
0.16653631627559662, Epoch: 4\n", "Loss: 0.1408710777759552, Epoch: 4\n", "Loss: 0.13909664750099182, Epoch: 4\n", "Loss: 0.13380306959152222, Epoch: 4\n", "Loss: 0.1373494416475296, Epoch: 4\n", "Loss: 0.14459766447544098, Epoch: 4\n", "Loss: 0.13684412837028503, Epoch: 4\n", "Loss: 0.13902482390403748, Epoch: 4\n", "Loss: 0.1350281685590744, Epoch: 4\n", "Loss: 0.14071603119373322, Epoch: 4\n", "Loss: 0.13675397634506226, Epoch: 4\n", "Loss: 0.14031022787094116, Epoch: 4\n", "Loss: 0.13689054548740387, Epoch: 4\n", "Loss: 0.13759593665599823, Epoch: 4\n", "Loss: 0.1371559202671051, Epoch: 4\n", "Loss: 0.13737107813358307, Epoch: 4\n", "Loss: 0.13100741803646088, Epoch: 4\n", "Loss: 0.1358410120010376, Epoch: 4\n", "Loss: 0.1382497102022171, Epoch: 4\n", "Loss: 0.13645248115062714, Epoch: 4\n", "Loss: 0.1361786127090454, Epoch: 4\n", "Loss: 0.13501687347888947, Epoch: 4\n", "Loss: 0.1417567878961563, Epoch: 4\n", "Loss: 0.14279848337173462, Epoch: 4\n", "Loss: 0.1270349621772766, Epoch: 4\n", "Loss: 0.1460144817829132, Epoch: 4\n", "Loss: 0.1387779265642166, Epoch: 4\n", "Loss: 0.1424030065536499, Epoch: 4\n", "Loss: 0.14144709706306458, Epoch: 4\n", "Loss: 0.13189923763275146, Epoch: 4\n", "Loss: 0.18845082819461823, Epoch: 4\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.12886185944080353, Epoch: 4\n", "Loss: 0.137564018368721, Epoch: 4\n", "Loss: 0.16131561994552612, Epoch: 4\n", "Loss: 0.1416606456041336, Epoch: 4\n", "Loss: 0.14148710668087006, Epoch: 4\n", "Loss: 0.14193467795848846, Epoch: 4\n", "Loss: 0.1345425546169281, Epoch: 4\n", "Loss: 0.11765842884778976, Epoch: 4\n", "Loss: 0.19954945147037506, Epoch: 4\n", "Loss: 0.12320726364850998, Epoch: 4\n", "Loss: 0.1290886402130127, Epoch: 4\n", "Loss: 0.13515831530094147, Epoch: 4\n", "Loss: 0.1401388943195343, Epoch: 4\n", "Loss: 0.17238909006118774, Epoch: 4\n", "Loss: 0.13727407157421112, Epoch: 4\n", "Loss: 0.1388918161392212, Epoch: 4\n", "Loss: 0.12661297619342804, Epoch: 
4\n", "Loss: 0.1401134878396988, Epoch: 4\n", "Loss: 0.1431579738855362, Epoch: 4\n", "Loss: 0.19006170332431793, Epoch: 4\n", "Loss: 0.14119674265384674, Epoch: 4\n", "Loss: 0.1283990442752838, Epoch: 4\n", "Loss: 0.13079889118671417, Epoch: 4\n", "Loss: 0.1374884694814682, Epoch: 4\n", "Loss: 0.13765540719032288, Epoch: 4\n", "Loss: 0.11636094003915787, Epoch: 4\n", "Loss: 0.13191860914230347, Epoch: 4\n", "Loss: 0.13867181539535522, Epoch: 4\n", "Loss: 0.14593051373958588, Epoch: 4\n", "Loss: 0.132233127951622, Epoch: 4\n", "Loss: 0.13886763155460358, Epoch: 4\n", "Loss: 0.14014451205730438, Epoch: 4\n", "Loss: 0.17880161106586456, Epoch: 4\n", "Loss: 0.1320728063583374, Epoch: 4\n", "Loss: 0.1291753053665161, Epoch: 4\n", "Loss: 0.138149231672287, Epoch: 4\n", "Loss: 0.1335010826587677, Epoch: 4\n", "Loss: 0.1274981051683426, Epoch: 4\n", "Loss: 0.13696089386940002, Epoch: 4\n", "Loss: 0.13833221793174744, Epoch: 4\n", "Loss: 0.144839346408844, Epoch: 4\n", "Loss: 0.15894193947315216, Epoch: 4\n", "Loss: 0.15039022266864777, Epoch: 4\n", "Loss: 0.13267137110233307, Epoch: 4\n", "Loss: 0.1594497561454773, Epoch: 4\n", "Loss: 0.12029702961444855, Epoch: 4\n", "Loss: 0.12143255770206451, Epoch: 4\n", "Loss: 0.13340750336647034, Epoch: 4\n", "Loss: 0.1255176067352295, Epoch: 4\n", "Loss: 0.12960122525691986, Epoch: 4\n", "Loss: 0.13134732842445374, Epoch: 4\n", "Loss: 0.13162006437778473, Epoch: 4\n", "Loss: 0.17416930198669434, Epoch: 4\n", "Loss: 0.12281976640224457, Epoch: 4\n", "Loss: 0.1262483447790146, Epoch: 4\n", "Loss: 0.13074268400669098, Epoch: 4\n", "Loss: 0.14798814058303833, Epoch: 4\n", "Loss: 0.13607388734817505, Epoch: 4\n", "Loss: 0.1388680636882782, Epoch: 4\n", "Loss: 0.13348595798015594, Epoch: 4\n", "Loss: 0.1340729147195816, Epoch: 4\n", "Loss: 0.12516005337238312, Epoch: 4\n", "Loss: 0.138189896941185, Epoch: 4\n", "Loss: 0.13944798707962036, Epoch: 4\n", "Loss: 0.12049956619739532, Epoch: 4\n", "Loss: 0.1304030865430832, Epoch: 4\n", "Loss: 
0.1271418333053589, Epoch: 4\n", "Loss: 0.129023939371109, Epoch: 4\n", "Loss: 0.13616971671581268, Epoch: 4\n", "Loss: 0.1296188086271286, Epoch: 4\n", "Loss: 0.13445493578910828, Epoch: 4\n", "Loss: 0.12608228623867035, Epoch: 4\n", "Loss: 0.1271897405385971, Epoch: 4\n", "Loss: 0.12908056378364563, Epoch: 4\n", "Loss: 0.16020454466342926, Epoch: 4\n", "Loss: 0.127391055226326, Epoch: 4\n", "Loss: 0.1296154111623764, Epoch: 4\n", "Loss: 0.13318374752998352, Epoch: 4\n", "Loss: 0.1316712349653244, Epoch: 4\n", "Loss: 0.12991346418857574, Epoch: 4\n", "Loss: 0.1342085301876068, Epoch: 4\n", "Loss: 0.12700845301151276, Epoch: 4\n", "Loss: 0.1335075944662094, Epoch: 4\n", "Loss: 0.12601137161254883, Epoch: 4\n", "Loss: 0.134405255317688, Epoch: 4\n", "Loss: 0.11723560839891434, Epoch: 4\n", "Loss: 0.12238363921642303, Epoch: 4\n", "Loss: 0.12179684638977051, Epoch: 4\n", "Loss: 0.12812206149101257, Epoch: 4\n", "Loss: 0.13349272310733795, Epoch: 4\n", "Loss: 0.13319160044193268, Epoch: 4\n", "Loss: 0.11952850967645645, Epoch: 4\n", "Loss: 0.1268928050994873, Epoch: 4\n", "Loss: 0.13243825733661652, Epoch: 4\n", "Loss: 0.13336098194122314, Epoch: 4\n", "Loss: 0.14038413763046265, Epoch: 4\n", "Loss: 0.1249275803565979, Epoch: 5\n", "Loss: 0.12349969148635864, Epoch: 5\n", "Loss: 0.14533382654190063, Epoch: 5\n", "Loss: 0.1261834055185318, Epoch: 5\n", "Loss: 0.14619743824005127, Epoch: 5\n", "Loss: 0.13299092650413513, Epoch: 5\n", "Loss: 0.1651540845632553, Epoch: 5\n", "Loss: 0.13018113374710083, Epoch: 5\n", "Loss: 0.12087367475032806, Epoch: 5\n", "Loss: 0.1217261552810669, Epoch: 5\n", "Loss: 0.13018827140331268, Epoch: 5\n", "Loss: 0.13194221258163452, Epoch: 5\n", "Loss: 0.12210225313901901, Epoch: 5\n", "Loss: 0.13475942611694336, Epoch: 5\n", "Loss: 0.13920359313488007, Epoch: 5\n", "Loss: 0.13548718392848969, Epoch: 5\n", "Loss: 0.12695132195949554, Epoch: 5\n", "Loss: 0.13285671174526215, Epoch: 5\n", "Loss: 0.12065783143043518, Epoch: 5\n", "Loss: 
0.1276872307062149, Epoch: 5\n", "Loss: 0.13304223120212555, Epoch: 5\n", "Loss: 0.12204690277576447, Epoch: 5\n", "Loss: 0.12558265030384064, Epoch: 5\n", "Loss: 0.10881917923688889, Epoch: 5\n", "Loss: 0.12909428775310516, Epoch: 5\n", "Loss: 0.11289199441671371, Epoch: 5\n", "Loss: 0.12797780334949493, Epoch: 5\n", "Loss: 0.1226205825805664, Epoch: 5\n", "Loss: 0.11554985493421555, Epoch: 5\n", "Loss: 0.1257942020893097, Epoch: 5\n", "Loss: 0.13876764476299286, Epoch: 5\n", "Loss: 0.15175792574882507, Epoch: 5\n", "Loss: 0.12390465289354324, Epoch: 5\n", "Loss: 0.1710795760154724, Epoch: 5\n", "Loss: 0.1199038103222847, Epoch: 5\n", "Loss: 0.1343294382095337, Epoch: 5\n", "Loss: 0.1328621208667755, Epoch: 5\n", "Loss: 0.11477426439523697, Epoch: 5\n", "Loss: 0.12490951269865036, Epoch: 5\n", "Loss: 0.11478014290332794, Epoch: 5\n", "Loss: 0.13298383355140686, Epoch: 5\n", "Loss: 0.13167645037174225, Epoch: 5\n", "Loss: 0.12593436241149902, Epoch: 5\n", "Loss: 0.13232290744781494, Epoch: 5\n", "Loss: 0.13195490837097168, Epoch: 5\n", "Loss: 0.12542849779129028, Epoch: 5\n", "Loss: 0.13812699913978577, Epoch: 5\n", "Loss: 0.12022323906421661, Epoch: 5\n", "Loss: 0.12176539748907089, Epoch: 5\n", "Loss: 0.12790079414844513, Epoch: 5\n", "Loss: 0.12351775169372559, Epoch: 5\n", "Loss: 0.11888983845710754, Epoch: 5\n", "Loss: 0.12403333187103271, Epoch: 5\n", "Loss: 0.13656170666217804, Epoch: 5\n", "Loss: 0.13339126110076904, Epoch: 5\n", "Loss: 0.11131792515516281, Epoch: 5\n", "Loss: 0.12859481573104858, Epoch: 5\n", "Loss: 0.13500623404979706, Epoch: 5\n", "Loss: 0.12879762053489685, Epoch: 5\n", "Loss: 0.1277874857187271, Epoch: 5\n", "Loss: 0.11442019045352936, Epoch: 5\n", "Loss: 0.1521713137626648, Epoch: 5\n", "Loss: 0.1213955506682396, Epoch: 5\n", "Loss: 0.13298127055168152, Epoch: 5\n", "Loss: 0.14172470569610596, Epoch: 5\n", "Loss: 0.13587993383407593, Epoch: 5\n", "Loss: 0.14074204862117767, Epoch: 5\n", "Loss: 0.13420546054840088, Epoch: 5\n", "Loss: 
0.12752939760684967, Epoch: 5\n", "Loss: 0.10668251663446426, Epoch: 5\n", "Loss: 0.19405558705329895, Epoch: 5\n", "Loss: 0.12440341711044312, Epoch: 5\n", "Loss: 0.12243395298719406, Epoch: 5\n", "Loss: 0.12695355713367462, Epoch: 5\n", "Loss: 0.13261005282402039, Epoch: 5\n", "Loss: 0.1808023750782013, Epoch: 5\n", "Loss: 0.12548168003559113, Epoch: 5\n", "Loss: 0.1334279477596283, Epoch: 5\n", "Loss: 0.13674396276474, Epoch: 5\n", "Loss: 0.13056345283985138, Epoch: 5\n", "Loss: 0.13382840156555176, Epoch: 5\n", "Loss: 0.15729926526546478, Epoch: 5\n", "Loss: 0.12456748634576797, Epoch: 5\n", "Loss: 0.1150851845741272, Epoch: 5\n", "Loss: 0.12835916876792908, Epoch: 5\n", "Loss: 0.17279915511608124, Epoch: 5\n", "Loss: 0.13193294405937195, Epoch: 5\n", "Loss: 0.10962967574596405, Epoch: 5\n", "Loss: 0.12050677835941315, Epoch: 5\n", "Loss: 0.17489789426326752, Epoch: 5\n", "Loss: 0.13095614314079285, Epoch: 5\n", "Loss: 0.12931302189826965, Epoch: 5\n", "Loss: 0.12235209345817566, Epoch: 5\n", "Loss: 0.1342233270406723, Epoch: 5\n", "Loss: 0.17001110315322876, Epoch: 5\n", "Loss: 0.123800128698349, Epoch: 5\n", "Loss: 0.20985817909240723, Epoch: 5\n", "Loss: 0.13027140498161316, Epoch: 5\n", "Loss: 0.11940208822488785, Epoch: 5\n", "Loss: 0.13204459846019745, Epoch: 5\n", "Loss: 0.13626925647258759, Epoch: 5\n", "Loss: 0.12456974387168884, Epoch: 5\n", "Loss: 0.1372535526752472, Epoch: 5\n", "Loss: 0.12558864057064056, Epoch: 5\n", "Loss: 0.14300382137298584, Epoch: 5\n", "Loss: 0.1280129849910736, Epoch: 5\n", "Loss: 0.12656547129154205, Epoch: 5\n", "Loss: 0.10765340924263, Epoch: 5\n", "Loss: 0.10957075655460358, Epoch: 5\n", "Loss: 0.14550140500068665, Epoch: 5\n", "Loss: 0.12373486161231995, Epoch: 5\n", "Loss: 0.11921708285808563, Epoch: 5\n", "Loss: 0.11897507309913635, Epoch: 5\n", "Loss: 0.11595134437084198, Epoch: 5\n", "Loss: 0.15143433213233948, Epoch: 5\n", "Loss: 0.10788632184267044, Epoch: 5\n", "Loss: 0.1180427297949791, Epoch: 5\n", "Loss: 
0.12454649806022644, Epoch: 5\n", "Loss: 0.1388833224773407, Epoch: 5\n", "Loss: 0.12321362644433975, Epoch: 5\n", "Loss: 0.1299155354499817, Epoch: 5\n", "Loss: 0.12170170992612839, Epoch: 5\n", "Loss: 0.12160171568393707, Epoch: 5\n", "Loss: 0.11576078832149506, Epoch: 5\n", "Loss: 0.1474548727273941, Epoch: 5\n", "Loss: 0.12996402382850647, Epoch: 5\n", "Loss: 0.11152004450559616, Epoch: 5\n", "Loss: 0.11374114453792572, Epoch: 5\n", "Loss: 0.1124294102191925, Epoch: 5\n", "Loss: 0.11947697401046753, Epoch: 5\n", "Loss: 0.12320362776517868, Epoch: 5\n", "Loss: 0.12594015896320343, Epoch: 5\n", "Loss: 0.1282290816307068, Epoch: 5\n", "Loss: 0.11917021125555038, Epoch: 5\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.11522866785526276, Epoch: 5\n", "Loss: 0.11686207354068756, Epoch: 5\n", "Loss: 0.13003912568092346, Epoch: 5\n", "Loss: 0.10768930613994598, Epoch: 5\n", "Loss: 0.11665098369121552, Epoch: 5\n", "Loss: 0.12027396261692047, Epoch: 5\n", "Loss: 0.1223018616437912, Epoch: 5\n", "Loss: 0.11921297758817673, Epoch: 5\n", "Loss: 0.126190647482872, Epoch: 5\n", "Loss: 0.11158805340528488, Epoch: 5\n", "Loss: 0.12258896976709366, Epoch: 5\n", "Loss: 0.10933675616979599, Epoch: 5\n", "Loss: 0.12323165684938431, Epoch: 5\n", "Loss: 0.09231279790401459, Epoch: 5\n", "Loss: 0.10996890813112259, Epoch: 5\n", "Loss: 0.10218094289302826, Epoch: 5\n", "Loss: 0.11884516477584839, Epoch: 5\n", "Loss: 0.11803596466779709, Epoch: 5\n", "Loss: 0.11715345084667206, Epoch: 5\n", "Loss: 0.10423947125673294, Epoch: 5\n", "Loss: 0.11898159980773926, Epoch: 5\n", "Loss: 0.1188461184501648, Epoch: 5\n", "Loss: 0.12247420847415924, Epoch: 5\n", "Loss: 0.10961997509002686, Epoch: 5\n", "Loss: 0.11624372005462646, Epoch: 6\n", "Loss: 0.10715224593877792, Epoch: 6\n", "Loss: 0.13775856792926788, Epoch: 6\n", "Loss: 0.11116379499435425, Epoch: 6\n", "Loss: 0.13297320902347565, Epoch: 6\n", "Loss: 0.12099210172891617, Epoch: 6\n", "Loss: 0.11810016632080078, 
Epoch: 6\n", "Loss: 0.12210290879011154, Epoch: 6\n", "Loss: 0.1103905439376831, Epoch: 6\n", "Loss: 0.10696553438901901, Epoch: 6\n", "Loss: 0.11589460074901581, Epoch: 6\n", "Loss: 0.11533137410879135, Epoch: 6\n", "Loss: 0.10772214829921722, Epoch: 6\n", "Loss: 0.12689828872680664, Epoch: 6\n", "Loss: 0.12170947343111038, Epoch: 6\n", "Loss: 0.1141921728849411, Epoch: 6\n", "Loss: 0.11483687907457352, Epoch: 6\n", "Loss: 0.12149053812026978, Epoch: 6\n", "Loss: 0.10325966775417328, Epoch: 6\n", "Loss: 0.11688613146543503, Epoch: 6\n", "Loss: 0.11686090379953384, Epoch: 6\n", "Loss: 0.10853508859872818, Epoch: 6\n", "Loss: 0.12138164043426514, Epoch: 6\n", "Loss: 0.10637712478637695, Epoch: 6\n", "Loss: 0.11518768966197968, Epoch: 6\n", "Loss: 0.1005590632557869, Epoch: 6\n", "Loss: 0.10979194939136505, Epoch: 6\n", "Loss: 0.10970509797334671, Epoch: 6\n", "Loss: 0.10047701746225357, Epoch: 6\n", "Loss: 0.11055599898099899, Epoch: 6\n", "Loss: 0.1114952564239502, Epoch: 6\n", "Loss: 0.1298879086971283, Epoch: 6\n", "Loss: 0.1118747889995575, Epoch: 6\n", "Loss: 0.11944448947906494, Epoch: 6\n", "Loss: 0.11376014351844788, Epoch: 6\n", "Loss: 0.13856884837150574, Epoch: 6\n", "Loss: 0.12177485227584839, Epoch: 6\n", "Loss: 0.09694704413414001, Epoch: 6\n", "Loss: 0.11368582397699356, Epoch: 6\n", "Loss: 0.11257602274417877, Epoch: 6\n", "Loss: 0.1129927709698677, Epoch: 6\n", "Loss: 0.11432743817567825, Epoch: 6\n", "Loss: 0.11266694217920303, Epoch: 6\n", "Loss: 0.10029543936252594, Epoch: 6\n", "Loss: 0.10305653512477875, Epoch: 6\n", "Loss: 0.1084495559334755, Epoch: 6\n", "Loss: 0.1262296736240387, Epoch: 6\n", "Loss: 0.13879963755607605, Epoch: 6\n", "Loss: 0.11159215122461319, Epoch: 6\n", "Loss: 0.10721106827259064, Epoch: 6\n", "Loss: 0.10601997375488281, Epoch: 6\n", "Loss: 0.10585113614797592, Epoch: 6\n", "Loss: 0.12005317956209183, Epoch: 6\n", "Loss: 0.11511770635843277, Epoch: 6\n", "Loss: 0.12122543901205063, Epoch: 6\n", "Loss: 0.0961105078458786, 
Epoch: 6\n", "Loss: 0.11446699500083923, Epoch: 6\n", "Loss: 0.11680401861667633, Epoch: 6\n", "Loss: 0.11895857751369476, Epoch: 6\n", "Loss: 0.11428070813417435, Epoch: 6\n", "Loss: 0.10440615564584732, Epoch: 6\n", "Loss: 0.10592535138130188, Epoch: 6\n", "Loss: 0.1098354384303093, Epoch: 6\n", "Loss: 0.11369583010673523, Epoch: 6\n", "Loss: 0.11765711009502411, Epoch: 6\n", "Loss: 0.12683521211147308, Epoch: 6\n", "Loss: 0.14331921935081482, Epoch: 6\n", "Loss: 0.12109503895044327, Epoch: 6\n", "Loss: 0.19002178311347961, Epoch: 6\n", "Loss: 0.132974311709404, Epoch: 6\n", "Loss: 0.18188707530498505, Epoch: 6\n", "Loss: 0.1156173050403595, Epoch: 6\n", "Loss: 0.21222077310085297, Epoch: 6\n", "Loss: 0.11632455885410309, Epoch: 6\n", "Loss: 0.11654199659824371, Epoch: 6\n", "Loss: 0.2229778915643692, Epoch: 6\n", "Loss: 0.13579709827899933, Epoch: 6\n", "Loss: 0.13772152364253998, Epoch: 6\n", "Loss: 0.21927736699581146, Epoch: 6\n", "Loss: 0.14530208706855774, Epoch: 6\n", "Loss: 0.12683066725730896, Epoch: 6\n", "Loss: 0.12871938943862915, Epoch: 6\n", "Loss: 0.12640881538391113, Epoch: 6\n", "Loss: 0.12194661796092987, Epoch: 6\n", "Loss: 0.11213715374469757, Epoch: 6\n", "Loss: 0.13071177899837494, Epoch: 6\n", "Loss: 0.14279155433177948, Epoch: 6\n", "Loss: 0.09821438789367676, Epoch: 6\n", "Loss: 0.11000728607177734, Epoch: 6\n", "Loss: 0.13279211521148682, Epoch: 6\n", "Loss: 0.13721898198127747, Epoch: 6\n", "Loss: 0.11367487162351608, Epoch: 6\n", "Loss: 0.13350757956504822, Epoch: 6\n", "Loss: 0.1248038113117218, Epoch: 6\n", "Loss: 0.15275642275810242, Epoch: 6\n", "Loss: 0.1401827335357666, Epoch: 6\n", "Loss: 0.1166832372546196, Epoch: 6\n", "Loss: 0.1256875991821289, Epoch: 6\n", "Loss: 0.11419999599456787, Epoch: 6\n", "Loss: 0.11006093770265579, Epoch: 6\n", "Loss: 0.11871582269668579, Epoch: 6\n", "Loss: 0.11190492659807205, Epoch: 6\n", "Loss: 0.12537062168121338, Epoch: 6\n", "Loss: 0.08388262987136841, Epoch: 6\n", "Loss: 0.1354859173297882, 
Epoch: 6\n", "Loss: 0.1201799213886261, Epoch: 6\n", "Loss: 0.08500127494335175, Epoch: 6\n", "Loss: 0.09311848878860474, Epoch: 6\n", "Loss: 0.09804543852806091, Epoch: 6\n", "Loss: 0.13395635783672333, Epoch: 6\n", "Loss: 0.11851917952299118, Epoch: 6\n", "Loss: 0.10380667448043823, Epoch: 6\n", "Loss: 0.10534995794296265, Epoch: 6\n", "Loss: 0.10853283107280731, Epoch: 6\n", "Loss: 0.14459960162639618, Epoch: 6\n", "Loss: 0.09134180843830109, Epoch: 6\n", "Loss: 0.10183943808078766, Epoch: 6\n", "Loss: 0.11426414549350739, Epoch: 6\n", "Loss: 0.11680352687835693, Epoch: 6\n", "Loss: 0.1100248321890831, Epoch: 6\n", "Loss: 0.11217734217643738, Epoch: 6\n", "Loss: 0.11318906396627426, Epoch: 6\n", "Loss: 0.11201858520507812, Epoch: 6\n", "Loss: 0.1053985059261322, Epoch: 6\n", "Loss: 0.12794645130634308, Epoch: 6\n", "Loss: 0.11207936704158783, Epoch: 6\n", "Loss: 0.10350680351257324, Epoch: 6\n", "Loss: 0.10383594036102295, Epoch: 6\n", "Loss: 0.10246621817350388, Epoch: 6\n", "Loss: 0.10486087203025818, Epoch: 6\n", "Loss: 0.11843543499708176, Epoch: 6\n", "Loss: 0.10555055737495422, Epoch: 6\n", "Loss: 0.11247757077217102, Epoch: 6\n", "Loss: 0.11798767000436783, Epoch: 6\n", "Loss: 0.10157867521047592, Epoch: 6\n", "Loss: 0.10155266523361206, Epoch: 6\n", "Loss: 0.10669185221195221, Epoch: 6\n", "Loss: 0.09663646668195724, Epoch: 6\n", "Loss: 0.11037393659353256, Epoch: 6\n", "Loss: 0.10974664241075516, Epoch: 6\n", "Loss: 0.10823851823806763, Epoch: 6\n", "Loss: 0.1105186939239502, Epoch: 6\n", "Loss: 0.10878249257802963, Epoch: 6\n", "Loss: 0.0992928072810173, Epoch: 6\n", "Loss: 0.11492595076560974, Epoch: 6\n", "Loss: 0.10114570707082748, Epoch: 6\n", "Loss: 0.11305642127990723, Epoch: 6\n", "Loss: 0.08154793083667755, Epoch: 6\n", "Loss: 0.09805180877447128, Epoch: 6\n", "Loss: 0.08986265957355499, Epoch: 6\n", "Loss: 0.1070898175239563, Epoch: 6\n", "Loss: 0.10565295815467834, Epoch: 6\n", "Loss: 0.10418456047773361, Epoch: 6\n", "Loss: 
0.09737840294837952, Epoch: 6\n", "Loss: 0.1021442785859108, Epoch: 6\n", "Loss: 0.10725658386945724, Epoch: 6\n", "Loss: 0.10730786621570587, Epoch: 6\n", "Loss: 0.09617894887924194, Epoch: 6\n", "Loss: 0.11127995699644089, Epoch: 7\n", "Loss: 0.0874585509300232, Epoch: 7\n", "Loss: 0.12249134480953217, Epoch: 7\n", "Loss: 0.1020079031586647, Epoch: 7\n", "Loss: 0.11316166073083878, Epoch: 7\n", "Loss: 0.10774492472410202, Epoch: 7\n", "Loss: 0.09285763651132584, Epoch: 7\n", "Loss: 0.1098235622048378, Epoch: 7\n", "Loss: 0.09434773027896881, Epoch: 7\n", "Loss: 0.09606831520795822, Epoch: 7\n", "Loss: 0.10440029948949814, Epoch: 7\n", "Loss: 0.10212397575378418, Epoch: 7\n", "Loss: 0.09842287749052048, Epoch: 7\n", "Loss: 0.11679971218109131, Epoch: 7\n", "Loss: 0.10654377192258835, Epoch: 7\n", "Loss: 0.10638795793056488, Epoch: 7\n", "Loss: 0.1004040390253067, Epoch: 7\n", "Loss: 0.1108088567852974, Epoch: 7\n", "Loss: 0.09905845671892166, Epoch: 7\n", "Loss: 0.1053299531340599, Epoch: 7\n", "Loss: 0.1049308255314827, Epoch: 7\n", "Loss: 0.09539797902107239, Epoch: 7\n", "Loss: 0.096033975481987, Epoch: 7\n", "Loss: 0.09786777198314667, Epoch: 7\n", "Loss: 0.10437578707933426, Epoch: 7\n", "Loss: 0.08823202550411224, Epoch: 7\n", "Loss: 0.09576096385717392, Epoch: 7\n", "Loss: 0.09627911448478699, Epoch: 7\n", "Loss: 0.0886673852801323, Epoch: 7\n", "Loss: 0.10158203542232513, Epoch: 7\n", "Loss: 0.09936793148517609, Epoch: 7\n", "Loss: 0.10796260088682175, Epoch: 7\n", "Loss: 0.10200068354606628, Epoch: 7\n", "Loss: 0.1036141887307167, Epoch: 7\n", "Loss: 0.10264419764280319, Epoch: 7\n", "Loss: 0.1228143498301506, Epoch: 7\n", "Loss: 0.1109018325805664, Epoch: 7\n", "Loss: 0.08281095325946808, Epoch: 7\n", "Loss: 0.09651509672403336, Epoch: 7\n", "Loss: 0.10483517497777939, Epoch: 7\n", "Loss: 0.10054583847522736, Epoch: 7\n", "Loss: 0.09840130060911179, Epoch: 7\n", "Loss: 0.10182977467775345, Epoch: 7\n", "Loss: 0.09018167108297348, Epoch: 7\n", "Loss: 
0.08957422524690628, Epoch: 7\n", "Loss: 0.09855228662490845, Epoch: 7\n", "Loss: 0.11381145566701889, Epoch: 7\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.09318786859512329, Epoch: 7\n", "Loss: 0.09626708924770355, Epoch: 7\n", "Loss: 0.09697861224412918, Epoch: 7\n", "Loss: 0.09386548399925232, Epoch: 7\n", "Loss: 0.09933487325906754, Epoch: 7\n", "Loss: 0.08954917639493942, Epoch: 7\n", "Loss: 0.09967973083257675, Epoch: 7\n", "Loss: 0.10813768208026886, Epoch: 7\n", "Loss: 0.08184698224067688, Epoch: 7\n", "Loss: 0.10852080583572388, Epoch: 7\n", "Loss: 0.10763221979141235, Epoch: 7\n", "Loss: 0.11019491404294968, Epoch: 7\n", "Loss: 0.1040777862071991, Epoch: 7\n", "Loss: 0.0847632959485054, Epoch: 7\n", "Loss: 0.08171959221363068, Epoch: 7\n", "Loss: 0.08776959776878357, Epoch: 7\n", "Loss: 0.10209298133850098, Epoch: 7\n", "Loss: 0.08908098191022873, Epoch: 7\n", "Loss: 0.10392419248819351, Epoch: 7\n", "Loss: 0.09996441006660461, Epoch: 7\n", "Loss: 0.11016766726970673, Epoch: 7\n", "Loss: 0.10478966683149338, Epoch: 7\n", "Loss: 0.07866249978542328, Epoch: 7\n", "Loss: 0.13597208261489868, Epoch: 7\n", "Loss: 0.1079997643828392, Epoch: 7\n", "Loss: 0.08729257434606552, Epoch: 7\n", "Loss: 0.10136373341083527, Epoch: 7\n", "Loss: 0.09196136891841888, Epoch: 7\n", "Loss: 0.1400473266839981, Epoch: 7\n", "Loss: 0.09910447895526886, Epoch: 7\n", "Loss: 0.10202774405479431, Epoch: 7\n", "Loss: 0.0815005972981453, Epoch: 7\n", "Loss: 0.10101136565208435, Epoch: 7\n", "Loss: 0.0984138771891594, Epoch: 7\n", "Loss: 0.09421591460704803, Epoch: 7\n", "Loss: 0.10190987586975098, Epoch: 7\n", "Loss: 0.08138076961040497, Epoch: 7\n", "Loss: 0.08836308121681213, Epoch: 7\n", "Loss: 0.10892144590616226, Epoch: 7\n", "Loss: 0.10090276598930359, Epoch: 7\n", "Loss: 0.07948984205722809, Epoch: 7\n", "Loss: 0.09559652209281921, Epoch: 7\n", "Loss: 0.10386673361063004, Epoch: 7\n", "Loss: 0.10217111557722092, Epoch: 7\n", "Loss: 
0.09979840368032455, Epoch: 7\n", "Loss: 0.09837999939918518, Epoch: 7\n", "Loss: 0.10334828495979309, Epoch: 7\n", "Loss: 0.10868719965219498, Epoch: 7\n", "Loss: 0.0996246188879013, Epoch: 7\n", "Loss: 0.09603465348482132, Epoch: 7\n", "Loss: 0.0976654514670372, Epoch: 7\n", "Loss: 0.09486246854066849, Epoch: 7\n", "Loss: 0.08585795760154724, Epoch: 7\n", "Loss: 0.10285493731498718, Epoch: 7\n", "Loss: 0.09559769928455353, Epoch: 7\n", "Loss: 0.10589518398046494, Epoch: 7\n", "Loss: 0.0665736198425293, Epoch: 7\n", "Loss: 0.12515269219875336, Epoch: 7\n", "Loss: 0.10746419429779053, Epoch: 7\n", "Loss: 0.06850691139698029, Epoch: 7\n", "Loss: 0.07568707317113876, Epoch: 7\n", "Loss: 0.08142060041427612, Epoch: 7\n", "Loss: 0.14092078804969788, Epoch: 7\n", "Loss: 0.10802359133958817, Epoch: 7\n", "Loss: 0.09362595528364182, Epoch: 7\n", "Loss: 0.09656472504138947, Epoch: 7\n", "Loss: 0.09624163806438446, Epoch: 7\n", "Loss: 0.1221676841378212, Epoch: 7\n", "Loss: 0.08465592563152313, Epoch: 7\n", "Loss: 0.09251771122217178, Epoch: 7\n", "Loss: 0.1012636050581932, Epoch: 7\n", "Loss: 0.08949314802885056, Epoch: 7\n", "Loss: 0.09692538529634476, Epoch: 7\n", "Loss: 0.10708724707365036, Epoch: 7\n", "Loss: 0.10580960661172867, Epoch: 7\n", "Loss: 0.10133131593465805, Epoch: 7\n", "Loss: 0.09440037608146667, Epoch: 7\n", "Loss: 0.11331550776958466, Epoch: 7\n", "Loss: 0.09957834333181381, Epoch: 7\n", "Loss: 0.08154382556676865, Epoch: 7\n", "Loss: 0.09030596911907196, Epoch: 7\n", "Loss: 0.09509352594614029, Epoch: 7\n", "Loss: 0.09161731600761414, Epoch: 7\n", "Loss: 0.10449625551700592, Epoch: 7\n", "Loss: 0.09905510395765305, Epoch: 7\n", "Loss: 0.10003232210874557, Epoch: 7\n", "Loss: 0.10228797793388367, Epoch: 7\n", "Loss: 0.08921154588460922, Epoch: 7\n", "Loss: 0.09422476589679718, Epoch: 7\n", "Loss: 0.09613534063100815, Epoch: 7\n", "Loss: 0.09139202535152435, Epoch: 7\n", "Loss: 0.09728746861219406, Epoch: 7\n", "Loss: 0.0973246768116951, Epoch: 7\n", 
"Loss: 0.09548674523830414, Epoch: 7\n", "Loss: 0.10061587393283844, Epoch: 7\n", "Loss: 0.09739266335964203, Epoch: 7\n", "Loss: 0.08884819597005844, Epoch: 7\n", "Loss: 0.09968846291303635, Epoch: 7\n", "Loss: 0.08458088338375092, Epoch: 7\n", "Loss: 0.10188835114240646, Epoch: 7\n", "Loss: 0.0672396719455719, Epoch: 7\n", "Loss: 0.08573003113269806, Epoch: 7\n", "Loss: 0.07837706059217453, Epoch: 7\n", "Loss: 0.10032759606838226, Epoch: 7\n", "Loss: 0.09497064352035522, Epoch: 7\n", "Loss: 0.0931612178683281, Epoch: 7\n", "Loss: 0.0899437889456749, Epoch: 7\n", "Loss: 0.09274802356958389, Epoch: 7\n", "Loss: 0.09666411578655243, Epoch: 7\n", "Loss: 0.09964553266763687, Epoch: 7\n", "Loss: 0.08965051174163818, Epoch: 7\n", "Loss: 0.09531083703041077, Epoch: 8\n", "Loss: 0.0764368548989296, Epoch: 8\n", "Loss: 0.09976893663406372, Epoch: 8\n", "Loss: 0.09121975302696228, Epoch: 8\n", "Loss: 0.09898321330547333, Epoch: 8\n", "Loss: 0.09821636974811554, Epoch: 8\n", "Loss: 0.08225860446691513, Epoch: 8\n", "Loss: 0.10347147285938263, Epoch: 8\n", "Loss: 0.08411657810211182, Epoch: 8\n", "Loss: 0.08513495326042175, Epoch: 8\n", "Loss: 0.09337693452835083, Epoch: 8\n", "Loss: 0.09445364028215408, Epoch: 8\n", "Loss: 0.08690741658210754, Epoch: 8\n", "Loss: 0.10567010939121246, Epoch: 8\n", "Loss: 0.09979277849197388, Epoch: 8\n", "Loss: 0.10393866151571274, Epoch: 8\n", "Loss: 0.08208710700273514, Epoch: 8\n", "Loss: 0.10648505389690399, Epoch: 8\n", "Loss: 0.09276064485311508, Epoch: 8\n", "Loss: 0.10089381039142609, Epoch: 8\n", "Loss: 0.0965060368180275, Epoch: 8\n", "Loss: 0.0881606936454773, Epoch: 8\n", "Loss: 0.08871510624885559, Epoch: 8\n", "Loss: 0.09277530014514923, Epoch: 8\n", "Loss: 0.09860555082559586, Epoch: 8\n", "Loss: 0.07966770976781845, Epoch: 8\n", "Loss: 0.09030231833457947, Epoch: 8\n", "Loss: 0.0851370170712471, Epoch: 8\n", "Loss: 0.08056827634572983, Epoch: 8\n", "Loss: 0.09375553578138351, Epoch: 8\n", "Loss: 0.0939326360821724, Epoch: 
8\n", "Loss: 0.09433168917894363, Epoch: 8\n", "Loss: 0.09604447335004807, Epoch: 8\n", "Loss: 0.09176722913980484, Epoch: 8\n", "Loss: 0.09819795936346054, Epoch: 8\n", "Loss: 0.10539628565311432, Epoch: 8\n", "Loss: 0.10229707509279251, Epoch: 8\n", "Loss: 0.07657080888748169, Epoch: 8\n", "Loss: 0.09068753570318222, Epoch: 8\n", "Loss: 0.0944778323173523, Epoch: 8\n", "Loss: 0.09684833884239197, Epoch: 8\n", "Loss: 0.09293913841247559, Epoch: 8\n", "Loss: 0.10148358345031738, Epoch: 8\n", "Loss: 0.08338888734579086, Epoch: 8\n", "Loss: 0.08214575797319412, Epoch: 8\n", "Loss: 0.08863586187362671, Epoch: 8\n", "Loss: 0.10475130379199982, Epoch: 8\n", "Loss: 0.10403673350811005, Epoch: 8\n", "Loss: 0.08990421146154404, Epoch: 8\n", "Loss: 0.08858269453048706, Epoch: 8\n", "Loss: 0.08760341256856918, Epoch: 8\n", "Loss: 0.09343436360359192, Epoch: 8\n", "Loss: 0.0796404555439949, Epoch: 8\n", "Loss: 0.09368477016687393, Epoch: 8\n", "Loss: 0.0998200923204422, Epoch: 8\n", "Loss: 0.07597582042217255, Epoch: 8\n", "Loss: 0.10604184120893478, Epoch: 8\n", "Loss: 0.09754043072462082, Epoch: 8\n", "Loss: 0.10065465420484543, Epoch: 8\n", "Loss: 0.10647395253181458, Epoch: 8\n", "Loss: 0.08068972080945969, Epoch: 8\n", "Loss: 0.16115640103816986, Epoch: 8\n", "Loss: 0.11444346606731415, Epoch: 8\n", "Loss: 0.10705339163541794, Epoch: 8\n", "Loss: 0.09086104482412338, Epoch: 8\n", "Loss: 0.10477293282747269, Epoch: 8\n", "Loss: 0.09490729123353958, Epoch: 8\n", "Loss: 0.11211740970611572, Epoch: 8\n", "Loss: 0.10389804095029831, Epoch: 8\n", "Loss: 0.08844345808029175, Epoch: 8\n", "Loss: 0.21717138588428497, Epoch: 8\n", "Loss: 0.10999349504709244, Epoch: 8\n", "Loss: 0.08422992378473282, Epoch: 8\n", "Loss: 0.101418137550354, Epoch: 8\n", "Loss: 0.09684649854898453, Epoch: 8\n", "Loss: 0.1724022775888443, Epoch: 8\n", "Loss: 0.1314866691827774, Epoch: 8\n", "Loss: 0.09903855621814728, Epoch: 8\n", "Loss: 0.11843590438365936, Epoch: 8\n", "Loss: 0.12226806581020355, 
Epoch: 8\n", "Loss: 0.13927863538265228, Epoch: 8\n", "Loss: 0.10693665593862534, Epoch: 8\n", "Loss: 0.10284698754549026, Epoch: 8\n", "Loss: 0.09007545560598373, Epoch: 8\n", "Loss: 0.13318724930286407, Epoch: 8\n", "Loss: 0.11283812671899796, Epoch: 8\n", "Loss: 0.13335789740085602, Epoch: 8\n", "Loss: 0.08382488787174225, Epoch: 8\n", "Loss: 0.09779892861843109, Epoch: 8\n", "Loss: 0.11922663450241089, Epoch: 8\n", "Loss: 0.10391080379486084, Epoch: 8\n", "Loss: 0.09241219609975815, Epoch: 8\n", "Loss: 0.0914870947599411, Epoch: 8\n", "Loss: 0.10276245325803757, Epoch: 8\n", "Loss: 0.11850879341363907, Epoch: 8\n", "Loss: 0.11146795749664307, Epoch: 8\n", "Loss: 0.13590680062770844, Epoch: 8\n", "Loss: 0.10979834944009781, Epoch: 8\n", "Loss: 0.09215784817934036, Epoch: 8\n", "Loss: 0.10343850404024124, Epoch: 8\n", "Loss: 0.10027733445167542, Epoch: 8\n", "Loss: 0.09378498047590256, Epoch: 8\n", "Loss: 0.11168068647384644, Epoch: 8\n", "Loss: 0.0735282376408577, Epoch: 8\n", "Loss: 0.11101522296667099, Epoch: 8\n", "Loss: 0.09046553075313568, Epoch: 8\n", "Loss: 0.049216534942388535, Epoch: 8\n", "Loss: 0.07616625726222992, Epoch: 8\n", "Loss: 0.08197750896215439, Epoch: 8\n", "Loss: 0.11903025954961777, Epoch: 8\n", "Loss: 0.11152675747871399, Epoch: 8\n", "Loss: 0.09254954010248184, Epoch: 8\n", "Loss: 0.10088139027357101, Epoch: 8\n", "Loss: 0.09094111621379852, Epoch: 8\n", "Loss: 0.11987733095884323, Epoch: 8\n", "Loss: 0.07824293524026871, Epoch: 8\n", "Loss: 0.08746776729822159, Epoch: 8\n", "Loss: 0.09556399285793304, Epoch: 8\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loss: 0.10032295435667038, Epoch: 8\n", "Loss: 0.0936238169670105, Epoch: 8\n", "Loss: 0.09704777598381042, Epoch: 8\n", "Loss: 0.0960117056965828, Epoch: 8\n", "Loss: 0.09698610007762909, Epoch: 8\n", "Loss: 0.08561122417449951, Epoch: 8\n", "Loss: 0.11067008972167969, Epoch: 8\n", "Loss: 0.09560175240039825, Epoch: 8\n", "Loss: 0.07673484086990356, Epoch: 8\n", 
"Loss: 0.08287633210420609, Epoch: 8\n", "Loss: 0.0930718407034874, Epoch: 8\n", "Loss: 0.08970781415700912, Epoch: 8\n", "Loss: 0.1007469967007637, Epoch: 8\n", "Loss: 0.09398171305656433, Epoch: 8\n", "Loss: 0.09208820015192032, Epoch: 8\n", "Loss: 0.09928972274065018, Epoch: 8\n", "Loss: 0.08620091527700424, Epoch: 8\n", "Loss: 0.08966758847236633, Epoch: 8\n", "Loss: 0.09183480590581894, Epoch: 8\n", "Loss: 0.08743667602539062, Epoch: 8\n", "Loss: 0.08666228502988815, Epoch: 8\n", "Loss: 0.09077548235654831, Epoch: 8\n", "Loss: 0.09201730787754059, Epoch: 8\n", "Loss: 0.09336574375629425, Epoch: 8\n", "Loss: 0.0882939025759697, Epoch: 8\n", "Loss: 0.08134134858846664, Epoch: 8\n", "Loss: 0.09802527725696564, Epoch: 8\n", "Loss: 0.073699451982975, Epoch: 8\n", "Loss: 0.1019870713353157, Epoch: 8\n", "Loss: 0.06060896813869476, Epoch: 8\n", "Loss: 0.08239734917879105, Epoch: 8\n", "Loss: 0.0757080689072609, Epoch: 8\n", "Loss: 0.09910555928945541, Epoch: 8\n", "Loss: 0.09245341271162033, Epoch: 8\n", "Loss: 0.08715236186981201, Epoch: 8\n", "Loss: 0.08551301062107086, Epoch: 8\n", "Loss: 0.08673674613237381, Epoch: 8\n", "Loss: 0.09225232154130936, Epoch: 8\n", "Loss: 0.09661070257425308, Epoch: 8\n", "Loss: 0.0856349915266037, Epoch: 8\n", "Loss: 0.09376442432403564, Epoch: 9\n", "Loss: 0.074952132999897, Epoch: 9\n", "Loss: 0.0973762795329094, Epoch: 9\n", "Loss: 0.08727239817380905, Epoch: 9\n", "Loss: 0.09685800969600677, Epoch: 9\n", "Loss: 0.0946836918592453, Epoch: 9\n", "Loss: 0.07464674860239029, Epoch: 9\n", "Loss: 0.1091916412115097, Epoch: 9\n", "Loss: 0.08540654927492142, Epoch: 9\n", "Loss: 0.07873443514108658, Epoch: 9\n", "Loss: 0.09215348213911057, Epoch: 9\n", "Loss: 0.09173493087291718, Epoch: 9\n", "Loss: 0.0822596400976181, Epoch: 9\n", "Loss: 0.10001017898321152, Epoch: 9\n", "Loss: 0.09149155765771866, Epoch: 9\n", "Loss: 0.09980180859565735, Epoch: 9\n", "Loss: 0.07469983398914337, Epoch: 9\n", "Loss: 0.10162779688835144, Epoch: 9\n", 
"Loss: 0.08965972810983658, Epoch: 9\n", "Loss: 0.09272082149982452, Epoch: 9\n", "Loss: 0.09031746536493301, Epoch: 9\n", "Loss: 0.08489258587360382, Epoch: 9\n", "Loss: 0.08522868901491165, Epoch: 9\n", "Loss: 0.09051043540239334, Epoch: 9\n", "Loss: 0.09652537107467651, Epoch: 9\n", "Loss: 0.07605654001235962, Epoch: 9\n", "Loss: 0.08628328889608383, Epoch: 9\n", "Loss: 0.08136139810085297, Epoch: 9\n", "Loss: 0.07680477946996689, Epoch: 9\n", "Loss: 0.08961768448352814, Epoch: 9\n", "Loss: 0.08857917040586472, Epoch: 9\n", "Loss: 0.08441700041294098, Epoch: 9\n", "Loss: 0.09196633845567703, Epoch: 9\n", "Loss: 0.08878505229949951, Epoch: 9\n", "Loss: 0.09737300127744675, Epoch: 9\n", "Loss: 0.10665343701839447, Epoch: 9\n", "Loss: 0.0984821617603302, Epoch: 9\n", "Loss: 0.07351763546466827, Epoch: 9\n", "Loss: 0.08747286349534988, Epoch: 9\n", "Loss: 0.08971679955720901, Epoch: 9\n", "Loss: 0.08926985412836075, Epoch: 9\n", "Loss: 0.08825411647558212, Epoch: 9\n", "Loss: 0.09101610630750656, Epoch: 9\n", "Loss: 0.07893285900354385, Epoch: 9\n", "Loss: 0.07997577637434006, Epoch: 9\n", "Loss: 0.08593187481164932, Epoch: 9\n", "Loss: 0.09999746084213257, Epoch: 9\n", "Loss: 0.0822121724486351, Epoch: 9\n", "Loss: 0.08718079328536987, Epoch: 9\n", "Loss: 0.08673427253961563, Epoch: 9\n", "Loss: 0.08414871990680695, Epoch: 9\n", "Loss: 0.08811739832162857, Epoch: 9\n", "Loss: 0.0737643912434578, Epoch: 9\n", "Loss: 0.09219112247228622, Epoch: 9\n", "Loss: 0.09499291330575943, Epoch: 9\n", "Loss: 0.07242050766944885, Epoch: 9\n", "Loss: 0.10049676150083542, Epoch: 9\n", "Loss: 0.08951854705810547, Epoch: 9\n", "Loss: 0.0999721810221672, Epoch: 9\n", "Loss: 0.09247854351997375, Epoch: 9\n", "Loss: 0.07728950679302216, Epoch: 9\n", "Loss: 0.06448691338300705, Epoch: 9\n", "Loss: 0.07841381430625916, Epoch: 9\n", "Loss: 0.09069116413593292, Epoch: 9\n", "Loss: 0.07123135775327682, Epoch: 9\n", "Loss: 0.09597131609916687, Epoch: 9\n", "Loss: 0.09440863877534866, Epoch: 
9\n", "Loss: 0.10218434035778046, Epoch: 9\n", "Loss: 0.09929265826940536, Epoch: 9\n", "Loss: 0.07066919654607773, Epoch: 9\n", "Loss: 0.11114039272069931, Epoch: 9\n", "Loss: 0.1049976646900177, Epoch: 9\n", "Loss: 0.07026403397321701, Epoch: 9\n", "Loss: 0.09360364079475403, Epoch: 9\n", "Loss: 0.08021758496761322, Epoch: 9\n", "Loss: 0.13227015733718872, Epoch: 9\n", "Loss: 0.0857333168387413, Epoch: 9\n", "Loss: 0.08923828601837158, Epoch: 9\n", "Loss: 0.06564822793006897, Epoch: 9\n", "Loss: 0.08624057471752167, Epoch: 9\n", "Loss: 0.0866137221455574, Epoch: 9\n", "Loss: 0.06691738963127136, Epoch: 9\n", "Loss: 0.0908471941947937, Epoch: 9\n", "Loss: 0.07081125676631927, Epoch: 9\n", "Loss: 0.07842528074979782, Epoch: 9\n", "Loss: 0.11552178114652634, Epoch: 9\n", "Loss: 0.09867746382951736, Epoch: 9\n", "Loss: 0.07390105724334717, Epoch: 9\n", "Loss: 0.08417465537786484, Epoch: 9\n", "Loss: 0.0967276319861412, Epoch: 9\n", "Loss: 0.0920858383178711, Epoch: 9\n", "Loss: 0.07572290301322937, Epoch: 9\n", "Loss: 0.08656008541584015, Epoch: 9\n", "Loss: 0.0918826162815094, Epoch: 9\n", "Loss: 0.10554063320159912, Epoch: 9\n", "Loss: 0.08934647589921951, Epoch: 9\n", "Loss: 0.09057360887527466, Epoch: 9\n", "Loss: 0.09137235581874847, Epoch: 9\n", "Loss: 0.08764832466840744, Epoch: 9\n", "Loss: 0.07917024940252304, Epoch: 9\n", "Loss: 0.0912671834230423, Epoch: 9\n", "Loss: 0.08849883079528809, Epoch: 9\n", "Loss: 0.09417363256216049, Epoch: 9\n", "Loss: 0.04771730676293373, Epoch: 9\n", "Loss: 0.10762224346399307, Epoch: 9\n", "Loss: 0.08276030421257019, Epoch: 9\n", "Loss: 0.052407264709472656, Epoch: 9\n", "Loss: 0.06693300604820251, Epoch: 9\n", "Loss: 0.07219705730676651, Epoch: 9\n", "Loss: 0.10493184626102448, Epoch: 9\n", "Loss: 0.1038055345416069, Epoch: 9\n", "Loss: 0.08958692103624344, Epoch: 9\n", "Loss: 0.08443622291088104, Epoch: 9\n", "Loss: 0.0833052545785904, Epoch: 9\n", "Loss: 0.10586792975664139, Epoch: 9\n", "Loss: 0.0817166417837143, Epoch: 
9\n", "Loss: 0.08074997365474701, Epoch: 9\n", "Loss: 0.09111910313367844, Epoch: 9\n", "Loss: 0.07266872376203537, Epoch: 9\n", "Loss: 0.09336880594491959, Epoch: 9\n", "Loss: 0.08866416662931442, Epoch: 9\n", "Loss: 0.08915765583515167, Epoch: 9\n", "Loss: 0.09368506819009781, Epoch: 9\n", "Loss: 0.08114416897296906, Epoch: 9\n", "Loss: 0.10218866169452667, Epoch: 9\n", "Loss: 0.09324317425489426, Epoch: 9\n", "Loss: 0.06947305798530579, Epoch: 9\n", "Loss: 0.07948879152536392, Epoch: 9\n", "Loss: 0.08767526596784592, Epoch: 9\n", "Loss: 0.08044221997261047, Epoch: 9\n", "Loss: 0.09291483461856842, Epoch: 9\n", "Loss: 0.09230954945087433, Epoch: 9\n", "Loss: 0.08549445867538452, Epoch: 9\n", "Loss: 0.09530370682477951, Epoch: 9\n", "Loss: 0.08001953363418579, Epoch: 9\n", "Loss: 0.08556060492992401, Epoch: 9\n", "Loss: 0.08849553018808365, Epoch: 9\n", "Loss: 0.08635181933641434, Epoch: 9\n", "Loss: 0.07598816603422165, Epoch: 9\n", "Loss: 0.08768333494663239, Epoch: 9\n", "Loss: 0.08585868775844574, Epoch: 9\n", "Loss: 0.08838070183992386, Epoch: 9\n", "Loss: 0.08511821925640106, Epoch: 9\n", "Loss: 0.07946973294019699, Epoch: 9\n", "Loss: 0.09176114201545715, Epoch: 9\n", "Loss: 0.06730469316244125, Epoch: 9\n", "Loss: 0.0944196805357933, Epoch: 9\n", "Loss: 0.05201512947678566, Epoch: 9\n", "Loss: 0.07947589457035065, Epoch: 9\n", "Loss: 0.07191232591867447, Epoch: 9\n", "Loss: 0.09604261815547943, Epoch: 9\n", "Loss: 0.09151257574558258, Epoch: 9\n", "Loss: 0.08036286383867264, Epoch: 9\n", "Loss: 0.08668318390846252, Epoch: 9\n", "Loss: 0.08424527198076248, Epoch: 9\n", "Loss: 0.08812244981527328, Epoch: 9\n", "Loss: 0.08832524716854095, Epoch: 9\n", "Loss: 0.08474713563919067, Epoch: 9\n" ] } ], "source": [
"# Number of samples whose gradients are summed before each optimizer step.\n",
"ACCUM_STEPS = 10\n",
"\n",
"for epoch in range(10):\n",
"    net.zero_grad()\n",
"    pending = 0  # backward passes accumulated since the last optimizer step\n",
"    for image, context, prediction in dataset:\n",
"        # Samples are fed one at a time, so add a leading batch dimension of 1.\n",
"        image = image.unsqueeze(0).cuda()\n",
"        context = context.unsqueeze(0).cuda()\n",
"        prediction = prediction.cuda()\n",
"        output = net(image, context)\n",
"        output = output.squeeze(0)\n",
"        # CrossEntropyLoss expects class indices; argmax over dim 1 turns each\n",
"        # target row into an index (presumably one-hot rows -- confirm in UIDataset).\n",
"        prediction = torch.argmax(prediction, 1)\n",
"        loss = criterion(output, prediction)\n",
"        loss.backward()\n",
"        pending += 1\n",
"        # Step only after a full group of ACCUM_STEPS samples, so that no update\n",
"        # is built from a single sample at the start of an epoch.\n",
"        if pending == ACCUM_STEPS:\n",
"            optimizer.step()\n",
"            # Reports the loss of the last sample in the group only.\n",
"            print('Loss: {}, Epoch: {}'.format(loss.item(), epoch))\n",
"            net.zero_grad()\n",
"            pending = 0\n",
"    if pending:\n",
"        # Apply the gradients of the final, partial group instead of letting\n",
"        # the next epoch's zero_grad() discard them.\n",
"        optimizer.step()\n",
"        net.zero_grad()\n",
"\n",
"torch.save(net.state_dict(), './pix2code.weights')"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Testing" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Pix2Code(\n", " (image_encoder): ImageEncoder(\n", " (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (conv5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))\n", " (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", " (fc1): Linear(in_features=100352, out_features=1024, bias=True)\n", " (fc2): Linear(in_features=1024, out_features=1024, bias=True)\n", " )\n", " (context_encoder): ContextEncoder(\n", " (rnn): RNN(19, 128, num_layers=2, batch_first=True)\n", " )\n", " (decoder): Decoder(\n", " (rnn): RNN(1152, 512, num_layers=2, batch_first=True)\n", " (l1): Linear(in_features=512, out_features=19, bias=True)\n", " )\n", ")" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net = Pix2Code()\n", "net.load_state_dict(torch.load('./pix2code.weights'))\n", "net.cuda().eval()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "test_data = UIDataset('./dataset/evaluation', 'voc.pkl')\n", "vocab = Vocabulary('voc.pkl')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ],
"source": [
"# Pick a random evaluation sample; only its image is needed here.\n",
"image, *_ = test_data[np.random.randint(len(test_data))]\n",
"t = transforms.ToPILImage()\n",
"# Keep a batch dimension on `image` for the decoding cell below.\n",
"image = image.unsqueeze(0)\n",
"t(image.squeeze())"
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "header{\n", "btn-inactive,btn-active\n", "}\n", "row{\n", "single{\n", "small-title,text,btn-green\n", "}\n", "}\n", "row{\n", "double{\n", "small-title,text,btn-green\n", "}\n", "double{\n", "small-title,text,btn-green\n", "}\n", "}\n", "row{\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "quadruple{\n", "small-title,text,btn-green\n", "}\n", "}\n", "\n" ] } ], "source": [
"# Upper bound on generated tokens, in case the stop token is never produced.\n",
"MAX_TOKENS = 200\n",
"\n",
"image = image.cuda()\n",
"# NOTE(review): the two seed tokens and the stop token below read as ' ' and ''.\n",
"# They look like START/END-style marker tokens that may have been stripped from\n",
"# this copy of the notebook -- confirm against the vocabulary in voc.pkl.\n",
"ct = []\n",
"ct.append(vocab.to_vec(' '))\n",
"ct.append(vocab.to_vec(''))\n",
"output = ''\n",
"# Inference only: no_grad() avoids building an autograd graph on every forward pass.\n",
"with torch.no_grad():\n",
"    for i in range(MAX_TOKENS):\n",
"        # The whole context so far is re-encoded on every step.\n",
"        context = torch.tensor(ct).unsqueeze(0).float().cuda()\n",
"        # Greedy decoding: argmax over dim 2, then keep the last sequence position.\n",
"        index = torch.argmax(net(image, context), 2).squeeze()[-1:].squeeze()\n",
"        v = vocab.to_vocab(int(index))\n",
"        if v == '':\n",
"            break\n",
"        output += v\n",
"        ct.append(vocab.to_vec(v))\n",
"\n",
"with open('./compiler/output.gui', 'w') as f:\n",
"    f.write(output)\n",
"\n",
"print(output)"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now from the compiler directory in your terminal run\n", "`python web-compiler.py output.gui`.\n", "This will generate an `output.html` file that you can open in your browser."
] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.2 64-bit ('pytorch': conda)", "language": "python", "name": "python38264bitpytorchcondaf04cb2303bb94659b54446e023c3cb62" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }