Spaces:
Runtime error
Runtime error
import paddle | |
import paddle.nn as nn | |
import math | |
class Painter(nn.Layer):
    """Stroke-prediction network written in Paddle.

    A target image and the current canvas are encoded by two separate (but
    structurally identical) convolutional encoders. Their feature maps are
    concatenated along the channel axis, projected to ``hidden_dim`` channels
    with a 1x1 convolution, combined with learned row/column positional
    embeddings and fed to a transformer whose decoder queries are a learned
    table with one row per stroke. Two heads then map every decoder output to
    stroke parameters and to a single per-stroke decision value.
    """

    def __init__(self, param_per_stroke, total_strokes, hidden_dim, n_heads=8, n_enc_layers=3, n_dec_layers=3,
                 max_feat_size=8):
        """Build the encoders, transformer, output heads and learned embeddings.

        Args:
            param_per_stroke: Size of the parameter vector predicted for each stroke.
            total_strokes: Number of decoder queries, i.e. strokes predicted per forward pass.
            hidden_dim: Transformer model width. It should be even, because the positional
                embedding is the concatenation of two halves of size ``hidden_dim // 2``
                and is added to features with ``hidden_dim`` channels.
            n_heads: Number of attention heads passed to ``nn.Transformer``.
            n_enc_layers: Number of transformer encoder layers.
            n_dec_layers: Number of transformer decoder layers.
            max_feat_size: Number of entries in each of the row and column positional
                embedding tables. The encoder output (input size reduced by two stride-2
                convolutions) must not be taller or wider than this. The default of 8
                keeps the parameter shapes of the original implementation.
        """
        super().__init__()
        # Both encoders share one architecture but keep independent weights.
        self.enc_img = self._make_encoder()
        self.enc_canvas = self._make_encoder()
        # 1x1 convolution fusing the concatenated image/canvas features.
        self.conv = nn.Conv2D(128 * 2, hidden_dim, 1)
        self.transformer = nn.Transformer(hidden_dim, n_heads, n_enc_layers, n_dec_layers)
        self.linear_param = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, param_per_stroke))
        self.linear_decider = nn.Linear(hidden_dim, 1)
        # Learned decoder queries: one row per predicted stroke.
        self.query_pos = paddle.static.create_parameter([total_strokes, hidden_dim], dtype='float32',
                                                        default_initializer=nn.initializer.Uniform(0, 1))
        # Learned positional tables; each supplies half of the positional embedding.
        self.row_embed = paddle.static.create_parameter([max_feat_size, hidden_dim // 2], dtype='float32',
                                                        default_initializer=nn.initializer.Uniform(0, 1))
        self.col_embed = paddle.static.create_parameter([max_feat_size, hidden_dim // 2], dtype='float32',
                                                        default_initializer=nn.initializer.Uniform(0, 1))

    @staticmethod
    def _make_encoder():
        """Return a fresh 3-channel -> 128-channel encoder that downsamples by two stride-2 convs."""
        layers = []
        # (in_channels, out_channels, stride) for each pad/conv/norm/activation stage.
        for in_ch, out_ch, stride in ((3, 32, 1), (32, 64, 2), (64, 128, 2)):
            layers += [
                nn.Pad2D([1, 1, 1, 1], 'reflect'),
                nn.Conv2D(in_ch, out_ch, 3, stride),
                nn.BatchNorm2D(out_ch),
                nn.ReLU(),
            ]
        return nn.Sequential(*layers)

    def forward(self, img, canvas):
        """Predict stroke parameters and decisions for a batch.

        Args:
            img: Target image batch with 3 channels, laid out as (batch, 3, H, W).
            canvas: Current canvas batch with the same layout as ``img``.

        Returns:
            A tuple ``(param, decision)``: ``param`` is the output of the parameter head
            (last axis of size ``param_per_stroke``) and ``decision`` is the raw output of
            the decision head (last axis of size 1), both computed for every decoder query.
            No activation is applied to either output here.
        """
        batch_size = img.shape[0]
        img_feat = self.enc_img(img)
        canvas_feat = self.enc_canvas(canvas)
        h, w = img_feat.shape[-2:]
        fused = self.conv(paddle.concat([img_feat, canvas_feat], axis=1))
        # Positional grid (h, w, hidden_dim): column half first, then row half,
        # flattened row-major to match ``fused.flatten(2)`` and given a broadcast batch axis.
        pos_embed = paddle.concat([
            self.col_embed[:w].unsqueeze(0).tile([h, 1, 1]),
            self.row_embed[:h].unsqueeze(1).tile([1, w, 1]),
        ], axis=-1).flatten(0, 1).unsqueeze(0)
        # Batch-first token sequence (batch, h * w, hidden_dim) for the transformer encoder.
        src = fused.flatten(2).transpose([0, 2, 1]) + pos_embed
        # The same learned queries are repeated for every sample in the batch.
        tgt = self.query_pos.unsqueeze(0).tile([batch_size, 1, 1])
        hidden_state = self.transformer(src, tgt)
        param = self.linear_param(hidden_state)
        decision = self.linear_decider(hidden_state)
        return param, decision