hysts HF staff commited on
Commit
002f9ff
1 Parent(s): cae9bf7
Files changed (1) hide show
  1. model.py +8 -11
model.py CHANGED
@@ -1,4 +1,3 @@
1
- import gc
2
  import tempfile
3
 
4
  import numpy as np
@@ -70,17 +69,15 @@ class Model:
70
  'cuda' if torch.cuda.is_available() else 'cpu')
71
  self.xm = load_model('transmitter', device=self.device)
72
  self.diffusion = diffusion_from_config(load_config('diffusion'))
73
- self.model_name = ''
74
- self.model = None
75
 
76
  def load_model(self, model_name: str) -> None:
77
  assert model_name in ['text300M', 'image300M']
78
- if model_name == self.model_name:
79
- return
80
- self.model = load_model(model_name, device=self.device)
81
- self.model_name = model_name
82
- gc.collect()
83
- torch.cuda.empty_cache()
84
 
85
  def to_glb(self, latent: torch.Tensor) -> str:
86
  ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
@@ -109,7 +106,7 @@ class Model:
109
 
110
  latents = sample_latents(
111
  batch_size=1,
112
- model=self.model,
113
  diffusion=self.diffusion,
114
  guidance_scale=guidance_scale,
115
  model_kwargs=dict(texts=[prompt]),
@@ -135,7 +132,7 @@ class Model:
135
  image = load_image(image_path)
136
  latents = sample_latents(
137
  batch_size=1,
138
- model=self.model,
139
  diffusion=self.diffusion,
140
  guidance_scale=guidance_scale,
141
  model_kwargs=dict(images=[image]),
 
 
1
  import tempfile
2
 
3
  import numpy as np
 
69
  'cuda' if torch.cuda.is_available() else 'cpu')
70
  self.xm = load_model('transmitter', device=self.device)
71
  self.diffusion = diffusion_from_config(load_config('diffusion'))
72
+ self.model_text = None
73
+ self.model_image = None
74
 
75
  def load_model(self, model_name: str) -> None:
76
  assert model_name in ['text300M', 'image300M']
77
+ if model_name == 'text300M' and self.model_text is None:
78
+ self.model_text = load_model(model_name, device=self.device)
79
+ elif model_name == 'image300M' and self.model_image is None:
80
+ self.model_image = load_model(model_name, device=self.device)
 
 
81
 
82
  def to_glb(self, latent: torch.Tensor) -> str:
83
  ply_path = tempfile.NamedTemporaryFile(suffix='.ply',
 
106
 
107
  latents = sample_latents(
108
  batch_size=1,
109
+ model=self.model_text,
110
  diffusion=self.diffusion,
111
  guidance_scale=guidance_scale,
112
  model_kwargs=dict(texts=[prompt]),
 
132
  image = load_image(image_path)
133
  latents = sample_latents(
134
  batch_size=1,
135
+ model=self.model_image,
136
  diffusion=self.diffusion,
137
  guidance_scale=guidance_scale,
138
  model_kwargs=dict(images=[image]),