changing custom pipeline and pinning requirements
Browse files- MyConfig.py +0 -1
- MyPipe.py +11 -14
- README.md +3 -10
- briarmbg.py +0 -1
- requirements.txt +1 -1
MyConfig.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
from transformers import PretrainedConfig
|
3 |
from typing import List
|
4 |
|
|
|
|
|
1 |
from transformers import PretrainedConfig
|
2 |
from typing import List
|
3 |
|
MyPipe.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import torch, os
|
3 |
import torch.nn.functional as F
|
4 |
from torchvision.transforms.functional import normalize
|
@@ -20,8 +19,8 @@ class RMBGPipe(Pipeline):
|
|
20 |
postprocess_kwargs = {}
|
21 |
if "model_input_size" in kwargs :
|
22 |
preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
|
23 |
-
if "
|
24 |
-
postprocess_kwargs["
|
25 |
return preprocess_kwargs, {}, postprocess_kwargs
|
26 |
|
27 |
def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
|
@@ -40,21 +39,19 @@ class RMBGPipe(Pipeline):
|
|
40 |
result = self.model(inputs.pop("image"))
|
41 |
inputs["result"] = result
|
42 |
return inputs
|
43 |
-
def postprocess(self,inputs,
|
44 |
result = inputs.pop("result")
|
45 |
orig_im_size = inputs.pop("orig_im_size")
|
46 |
im_path = inputs.pop("im_path")
|
47 |
result_image = self.postprocess_image(result[0][0], orig_im_size)
|
48 |
-
|
49 |
-
|
50 |
-
pil_im
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
return result_image
|
57 |
-
|
58 |
# utilities functions
|
59 |
def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
|
60 |
# same as utilities.py with minor modification
|
|
|
|
|
1 |
import torch, os
|
2 |
import torch.nn.functional as F
|
3 |
from torchvision.transforms.functional import normalize
|
|
|
19 |
postprocess_kwargs = {}
|
20 |
if "model_input_size" in kwargs :
|
21 |
preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
|
22 |
+
if "return_mask" in kwargs:
|
23 |
+
postprocess_kwargs["return_mask"] = kwargs["return_mask"]
|
24 |
return preprocess_kwargs, {}, postprocess_kwargs
|
25 |
|
26 |
def preprocess(self,im_path:str,model_input_size: list=[1024,1024]):
|
|
|
39 |
result = self.model(inputs.pop("image"))
|
40 |
inputs["result"] = result
|
41 |
return inputs
|
42 |
+
def postprocess(self,inputs,return_mask:bool=False ):
|
43 |
result = inputs.pop("result")
|
44 |
orig_im_size = inputs.pop("orig_im_size")
|
45 |
im_path = inputs.pop("im_path")
|
46 |
result_image = self.postprocess_image(result[0][0], orig_im_size)
|
47 |
+
pil_im = Image.fromarray(result_image)
|
48 |
+
if return_mask ==True :
|
49 |
+
return pil_im
|
50 |
+
no_bg_image = Image.new("RGBA", pil_im.size, (0,0,0,0))
|
51 |
+
orig_image = Image.open(im_path)
|
52 |
+
no_bg_image.paste(orig_image, mask=pil_im)
|
53 |
+
return no_bg_image
|
54 |
+
|
|
|
|
|
55 |
# utilities functions
|
56 |
def preprocess_image(self,im: np.ndarray, model_input_size: list=[1024,1024]) -> torch.Tensor:
|
57 |
# same as utilities.py with minor modification
|
README.md
CHANGED
@@ -110,13 +110,6 @@ or load the pipeline
|
|
110 |
```python
|
111 |
from transformers import pipeline
|
112 |
pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
|
113 |
-
|
114 |
-
pipe("image_path"
|
115 |
-
```
|
116 |
-
|
117 |
-
# parameters :
|
118 |
-
for the pipeline you can use the following parameters :
|
119 |
-
* `model_input_size` : default to [1024,1024]
|
120 |
-
* `out_name` : if specified it will use the numpy mask to extract the image and save it using the `out_name`
|
121 |
-
* `preprocess_image` : method for preprocessing images
|
122 |
-
* `postprocess_image` : method for postprocessing images
|
|
|
110 |
```python
|
111 |
from transformers import pipeline
|
112 |
pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
|
113 |
+
pillow_mask = pipe("img_path",return_mask = True) # outputs a pillow mask
|
114 |
+
pillow_image = pipe("image_path") # applies mask on input and returns a pillow image
|
115 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
briarmbg.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
requirements.txt
CHANGED
@@ -5,4 +5,4 @@ numpy
|
|
5 |
typing
|
6 |
scikit-image
|
7 |
huggingface_hub
|
8 |
-
|
|
|
5 |
typing
|
6 |
scikit-image
|
7 |
huggingface_hub
|
8 |
+
transformers==4.39.1
|