Questions about finetuning the model
It was my mistake. (Solved)
Solved how?
The JAX-CV codebase has the code needed to reset the head weights in the wd_taggers_v3_finetune branch; the only missing part is a few lines to load msgpack-serialized weights instead of an Orbax checkpoint.
First of all, thank you for your reply.
As you said, I was having trouble loading the msgpack weights correctly; I have now solved it by pulling the code from wdv3-jax.
I don't know much about deep learning, but since I was just adding a few tags (reusing unused tags, to be precise), I didn't reset the head weights.
It seems to work fine, although the MCC is a bit lower than expected. I'm going to lower the LR and try again.
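For anyone who does want to reset the head when the set of tags changes, here is a minimal sketch of what that can look like. The function name, the "head" key, and the initializer choice are all illustrative assumptions, not the JAX-CV implementation; adapt them to the actual parameter tree.

```python
import jax
import jax.numpy as jnp

def reset_head(params, rng, num_new_tags, head_key="head"):
    """Replace the classification head with freshly initialized weights.

    Assumes the head is a dense layer stored under params[head_key]
    with "kernel" (in_features, out_features) and "bias" entries.
    """
    in_features = params[head_key]["kernel"].shape[0]
    # Fresh truncated-normal kernel and zero bias for the new tag count.
    new_kernel = jax.nn.initializers.truncated_normal(stddev=0.02)(
        rng, (in_features, num_new_tags), jnp.float32
    )
    new_bias = jnp.zeros((num_new_tags,), jnp.float32)
    new_params = dict(params)
    new_params[head_key] = {"kernel": new_kernel, "bias": new_bias}
    return new_params
```

The rest of the backbone keeps its pretrained weights; only the head starts from scratch, which is why a lower LR on the backbone (or a head-only warmup) is a common follow-up.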
I've left the code below for the next person to fine-tune it.
....<omitted>
print(run_name)

# Load msgpack-serialized weights instead of an Orbax checkpoint
weights_path = "./model.msgpack"
with open(weights_path, "rb") as f:
    data = f.read()
restored = flax.serialization.msgpack_restore(data)["model"]
variables = {"params": restored["params"], **restored["constants"]}
state = state.replace(params=restored["params"])

if restore_params_ckpt or restore_simmim_ckpt:  # not needed when loading msgpack weights
....<omitted>
Thank you.
I've committed official support for msgpack files in the wd_taggers_v3_finetune branch.