Question about why it's compatible

#1
by feifeiobama - opened

Hi, I am a co-author of Pyramid Flow, and I hope to address the PyTorch version restriction as soon as possible. I just saw your comment:

Change this line:

self.timesteps_per_stage[i_s] = torch.from_numpy(timesteps[:-1])

To this:

self.timesteps_per_stage[i_s] = timesteps[:-1]

This will allow the model to be compatible with newer versions of PyTorch and other libraries than those pinned in the requirements file.

I wonder what the mechanism behind this is, and whether we should remove all torch.from_numpy calls like this one.

It seems to be from here:

timesteps = np.linspace(
                timestep_max, timestep_min, training_steps + 1,
            )

timestep_max and timestep_min here are singleton tensors, and apparently np.linspace, when given tensors as the endpoints, returns a tensor. I don't know which numpy version added this behavior.

I have one environment with numpy 1.26.4 and it doesn't happen there.

In version 2.0.2, though, it does. For example, this:

import numpy as np
import torch

a = torch.Tensor([0])
b = torch.Tensor([10])
c = np.linspace(a, b, 11)
print(type(c))

Prints <class 'torch.Tensor'>.

So the failure seems specific to this torch.from_numpy call: because the endpoints happened to already be tensors, np.linspace returns a tensor instead of a numpy array, and torch.from_numpy does not accept that.
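For what it's worth, here is a minimal sketch of that failure mode (the exact error message may vary between PyTorch versions): torch.from_numpy only accepts numpy ndarrays, so handing it a tensor raises a TypeError.

import torch

t = torch.tensor([0.0, 1.0, 2.0])
try:
    torch.from_numpy(t)   # from_numpy expects a numpy ndarray, not a Tensor
except TypeError as e:
    print(e)              # something like "expected np.ndarray (got Tensor)"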

Maybe this is a quirk specific to newer numpy versions, or not really an issue at all but a convenience behavior for keeping things as tensors that simply wasn't expected here.

Here is an alternative which should work for both versions:

timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)].item()
timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)].item()

This converts each singleton tensor into a plain Python number, which np.linspace then handles normally in both numpy versions.
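As a quick illustration (the 700/20 endpoints and the 11 steps below are stand-in values, not the model's actual schedule), passing plain floats keeps np.linspace returning an ndarray on both numpy 1.x and 2.x:

import numpy as np
import torch

timestep_max = torch.tensor([700.0]).item()   # .item() gives a plain Python float
timestep_min = torch.tensor([20.0]).item()
timesteps = np.linspace(timestep_max, timestep_min, 11)
print(type(timesteps))   # <class 'numpy.ndarray'> regardless of numpy version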

Not sure about other incompatibilities, but apart from this small change the rest of the code works with torch 2.4.1+cu124. I tested by just installing that plus the requirements file with == changed to >= for each requirement.

Thanks for the explanation, I will try it and see if it works in my environment.

Looks like PR #39 in the inference repo resolved it in an alternative way :)
