File size: 1,516 Bytes
3542be4
 
 
 
 
 
 
 
 
8264f82
3542be4
 
 
 
 
8264f82
3542be4
3f5a0c2
3542be4
 
088a973
3542be4
088a973
3542be4
088a973
3542be4
 
 
3f5a0c2
3542be4
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os

import gradio as gr

from utils.prompt_utils import remove_color
from utils.tagger import modelLoad, analysis


class PromptAnalysis:
    """Gradio panel that fills a prompt textbox with tags produced from an image.

    The tagger model is not loaded at construction time; it is loaded the
    first time ``process_prompt_analysis`` runs and then kept on the instance
    for subsequent calls.
    """

    def __init__(self, model_dir, post_filter=True,
                 default_nagative_prompt="lowres, error, extra digit, fewer digits, cropped, worst quality, "
                                         "low quality, normal quality, jpeg artifacts, blurry"):
        """Record configuration; no model is loaded and no widgets are built here.

        Args:
            model_dir: Forwarded unchanged to ``modelLoad`` and ``analysis``
                (presumably the tagger model directory -- confirm against
                ``utils.tagger``).
            post_filter: When true, the tagger output is passed through
                ``remove_color`` before it is returned.
            default_nagative_prompt: Initial text of the negative-prompt box.
                The spelling is kept as-is because callers may pass it by
                keyword and read the attribute of the same name.
        """
        # Placeholder until the first analysis request loads the tagger.
        self.model = None
        self.model_dir = model_dir
        self.post_filter = post_filter
        self.default_nagative_prompt = default_nagative_prompt

    def layout(self, input_image_path):
        """Build the prompt widgets and wire the analysis button.

        Presumably called while a ``gr.Blocks`` context is active, since it
        opens ``gr.Column``/``gr.Row`` layout contexts.

        Args:
            input_image_path: Gradio component whose value is handed to
                ``process_prompt_analysis`` when the button is clicked.

        Returns:
            ``[prompt_textbox, negative_prompt_textbox]`` -- the same objects
            stored on ``self.prompt`` and ``self.negative_prompt``.
        """
        with gr.Column():
            with gr.Row():
                prompt_box = gr.Textbox(label="prompt", lines=3)
            with gr.Row():
                negative_box = gr.Textbox(
                    label="negative_prompt",
                    lines=3,
                    value=self.default_nagative_prompt,
                )
            with gr.Row():
                run_button = gr.Button()

        # Expose the widgets as attributes so other code can reach them.
        self.prompt = prompt_box
        self.negative_prompt = negative_box
        self.prompt_analysis_button = run_button

        # A click writes the analysis result into the prompt box only; the
        # negative prompt is not an output and is left as the user set it.
        run_button.click(
            self.process_prompt_analysis,
            inputs=[input_image_path],
            outputs=prompt_box,
        )
        return [prompt_box, negative_box]

    def process_prompt_analysis(self, input_image_path):
        """Run the tagger on the image and return its (optionally filtered) output.

        Args:
            input_image_path: Value of the image input component, forwarded to
                ``analysis`` unchanged.

        Returns:
            The result of ``analysis``, passed through ``remove_color`` when
            ``self.post_filter`` is true. It is shown in the prompt textbox, so
            presumably a tag string -- confirm against ``utils.tagger``.
        """
        model = self.model
        if model is None:
            # First request: load the tagger and cache it for later clicks.
            model = self.model = modelLoad(self.model_dir)
        tags = analysis(input_image_path, self.model_dir, model)
        if not self.post_filter:
            return tags
        return remove_color(tags)