clement-bonnet commited on
Commit
f6ee8cd
1 Parent(s): 4bab54a

fix: x y corrdinates

Browse files
Files changed (2) hide show
  1. app.py +26 -18
  2. inference.py +2 -0
app.py CHANGED
@@ -22,18 +22,26 @@ def process_coord_click(image_idx: int, evt: gr.SelectData) -> Image.Image:
22
  Process the click event on the coordinate selector
23
  """
24
  x, y = evt.index[0], evt.index[1]
25
- x, y = x / 400, y / 400
26
- print(f"Clicked at coordinates: ({x:.3f}, {y:.3f})")
27
  return generate_image(image_idx, x, y)
28
 
29
 
30
  with gr.Blocks(
31
  css="""
32
- .radio-container {
33
  width: 450px !important;
34
  margin-left: auto !important;
35
  margin-right: auto !important;
36
  }
 
 
 
 
 
 
 
 
 
37
  """
38
  ) as demo:
39
  gr.Markdown(
@@ -44,7 +52,7 @@ with gr.Blocks(
44
  )
45
 
46
  with gr.Row():
47
- # Left column: Radio selection and reference image
48
  with gr.Column(scale=1):
49
  # State variable to track selected image index
50
  selected_idx = gr.State(value=0)
@@ -58,7 +66,7 @@ with gr.Blocks(
58
  interactive=True,
59
  )
60
 
61
- # Single reference image component that updates based on selection
62
  reference_image = gr.Image(
63
  value="imgs/pattern_0.png",
64
  show_label=False,
@@ -67,20 +75,20 @@ with gr.Blocks(
67
  width=450,
68
  )
69
 
70
- # Right column: Coordinate selector and output image
71
- with gr.Column(scale=1):
72
- # Coordinate selector with dynamic background
73
- coord_selector = gr.Image(
74
- value="imgs/heatmap_0.png",
75
- label="Click to select (x, y) coordinates in the latent space",
76
- show_label=True,
77
- interactive=True,
78
- height=400,
79
- width=400,
80
- )
81
 
82
- # Generated image output
83
- output_image = gr.Image(label="Generated Output", height=400, width=400)
 
 
 
 
 
 
 
 
 
84
 
85
  # Handle radio button selection
86
  task_select.change(
 
22
  Process the click event on the coordinate selector
23
  """
24
  x, y = evt.index[0], evt.index[1]
25
+ x, y = x / 1155, y / 1155 # Normalize the coordinates
 
26
  return generate_image(image_idx, x, y)
27
 
28
 
29
  with gr.Blocks(
30
  css="""
31
+ .radio-container {
32
  width: 450px !important;
33
  margin-left: auto !important;
34
  margin-right: auto !important;
35
  }
36
+ .coordinate-container {
37
+ width: 600px !important;
38
+ height: 600px !important;
39
+ }
40
+ .coordinate-container img {
41
+ width: 100% !important;
42
+ height: 100% !important;
43
+ object-fit: contain !important;
44
+ }
45
  """
46
  ) as demo:
47
  gr.Markdown(
 
52
  )
53
 
54
  with gr.Row():
55
+ # Left column: Radio selection, reference image, and output
56
  with gr.Column(scale=1):
57
  # State variable to track selected image index
58
  selected_idx = gr.State(value=0)
 
66
  interactive=True,
67
  )
68
 
69
+ # Reference image component that updates based on selection
70
  reference_image = gr.Image(
71
  value="imgs/pattern_0.png",
72
  show_label=False,
 
75
  width=450,
76
  )
77
 
78
+ # Generated image output moved below reference image
79
+ output_image = gr.Image(label="Generated Output", height=300, width=450)
 
 
 
 
 
 
 
 
 
80
 
81
+ # Right column: Larger coordinate selector
82
+ with gr.Column(scale=1):
83
+ # Coordinate selector with container class for proper scaling
84
+ with gr.Column(elem_classes="coordinate-container"):
85
+ coord_selector = gr.Image(
86
+ value="imgs/heatmap_0.png",
87
+ label="Click to select (x, y) coordinates in the latent space",
88
+ show_label=True,
89
+ interactive=True,
90
+ container=True,
91
+ )
92
 
93
  # Handle radio button selection
94
  task_select.change(
inference.py CHANGED
@@ -82,6 +82,8 @@ generate_output_from_context = jax.jit(
82
  def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
83
  # Create the input image
84
  input = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
 
 
85
  # Ensure x and y are in [eps, 1 - eps]
86
  x = min(1 - eps, max(eps, x))
87
  y = min(1 - eps, max(eps, y))
 
82
  def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
83
  # Create the input image
84
  input = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
85
+ # Inverse the y coordinate
86
+ y = 1 - y
87
  # Ensure x and y are in [eps, 1 - eps]
88
  x = min(1 - eps, max(eps, x))
89
  y = min(1 - eps, max(eps, y))