
Style transfer

#5
by make-gyver - opened

I'm trying style transfer by naively copying from example code, using GhostNet pulled from timm. That works, but the same approach with MobileNetV4 does not. Is there some obvious mistake?

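import timm
import torch
import torch.nn as nn


class GhostStyleModel(nn.Module):
    def __init__(self):
        super().__init__()
        model = timm.create_model('ghostnet_100', pretrained=True)
        # Drop the last four top-level children and keep the rest as stages.
        blocks = list(model.children())[:-4]
        self.net = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor):
        # Run the stages in order and collect every intermediate output.
        emb = []
        for block in self.net:
            x = block(x)
            emb.append(x)
        return emb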


class MobStyleV4Model(nn.Module):
    def __init__(self):
        super().__init__()
        model = timm.create_model('mobilenetv4_conv_small.e2400_r224_in1k', pretrained=True)
        blocks = list(model.children())[:-5]
        self.net = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor):
        emb = []
        for block in self.net:
            x = block(x)
            emb.append(x)
        return emb

PyTorch Image Models org

@make-gyver that's a very model-specific way of extracting features, and the features being extracted don't entirely make sense to me.

There is documentation on model-agnostic ways of extracting commonly used features: https://huggingface.co/docs/timm/en/feature_extraction

If you want to do it the way it's done for that GhostNet model, you'll have to manually pick through the modules and maybe select the specific equivalent ones to append to emb.

I started an experiment while awaiting your response and, interestingly enough, it kinda works with mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k.

This is supposed to be the Taj Mahal in the style of Van Gogh's Starry Night.
zwischenbild1000.jpg

This is the repo I forked from.
