TuringsSolutions
commited on
Commit
•
d6424f1
1
Parent(s):
ae98e87
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,10 @@ import numpy as np
|
|
4 |
from PIL import Image
|
5 |
from tqdm import tqdm
|
6 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Disable GPU usage by default
|
9 |
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
@@ -32,10 +36,10 @@ class SwarmAgent:
|
|
32 |
self.velocity = velocity
|
33 |
|
34 |
class SwarmNeuralNetwork:
|
35 |
-
def __init__(self, num_agents, image_shape,
|
36 |
self.image_shape = image_shape
|
37 |
self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
|
38 |
-
self.target_image =
|
39 |
|
40 |
def random_position(self):
|
41 |
return np.random.randn(*self.image_shape)
|
@@ -43,10 +47,6 @@ class SwarmNeuralNetwork:
|
|
43 |
def random_velocity(self):
|
44 |
return np.random.randn(*self.image_shape) * 0.01
|
45 |
|
46 |
-
def load_target_image(self, img_path):
|
47 |
-
img = Image.open(img_path).convert('RGB').resize((self.image_shape[1], self.image_shape[0]))
|
48 |
-
return np.array(img) / 127.5 - 1
|
49 |
-
|
50 |
def update_agents(self, timestep):
|
51 |
for agent in self.agents:
|
52 |
# Convert agent's position and target image into HDC space
|
@@ -74,18 +74,42 @@ class SwarmNeuralNetwork:
|
|
74 |
|
75 |
return self.generate_image()
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
interface = gr.Interface(
|
84 |
fn=train_snn,
|
85 |
inputs=[
|
86 |
-
gr.Image(type="
|
87 |
-
gr.Slider(minimum=100, maximum=1000, value=500, label="Number of Agents"),
|
88 |
-
gr.Slider(minimum=5, maximum=20, value=10, label="Number of Epochs")
|
89 |
],
|
90 |
outputs=gr.Image(type="numpy", label="Generated Image"),
|
91 |
title="HDC Swarm Neural Network Image Generation"
|
|
|
4 |
from PIL import Image
|
5 |
from tqdm import tqdm
|
6 |
import matplotlib.pyplot as plt
|
7 |
+
import logging
|
8 |
+
|
9 |
+
# Set up logging
|
10 |
+
logging.basicConfig(level=logging.INFO)
|
11 |
|
12 |
# Disable GPU usage by default
|
13 |
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
|
|
36 |
self.velocity = velocity
|
37 |
|
38 |
class SwarmNeuralNetwork:
|
39 |
+
def __init__(self, num_agents, image_shape, target_image):
|
40 |
self.image_shape = image_shape
|
41 |
self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
|
42 |
+
self.target_image = target_image
|
43 |
|
44 |
def random_position(self):
|
45 |
return np.random.randn(*self.image_shape)
|
|
|
47 |
def random_velocity(self):
|
48 |
return np.random.randn(*self.image_shape) * 0.01
|
49 |
|
|
|
|
|
|
|
|
|
50 |
def update_agents(self, timestep):
|
51 |
for agent in self.agents:
|
52 |
# Convert agent's position and target image into HDC space
|
|
|
74 |
|
75 |
return self.generate_image()
|
76 |
|
77 |
+
def preprocess_image(image):
|
78 |
+
"""Preprocess the input image."""
|
79 |
+
if image is None:
|
80 |
+
raise ValueError("No image provided")
|
81 |
+
|
82 |
+
if isinstance(image, np.ndarray):
|
83 |
+
# If it's already a numpy array, just resize and normalize
|
84 |
+
image = Image.fromarray(image)
|
85 |
+
elif isinstance(image, str):
|
86 |
+
# If it's a file path, open the image
|
87 |
+
image = Image.open(image)
|
88 |
+
else:
|
89 |
+
raise ValueError("Unsupported image type")
|
90 |
+
|
91 |
+
image = image.convert('RGB').resize((128, 128))
|
92 |
+
return np.array(image) / 127.5 - 1
|
93 |
+
|
94 |
+
def train_snn(image, num_agents, epochs):
|
95 |
+
try:
|
96 |
+
logging.info(f"Received image type: {type(image)}")
|
97 |
+
preprocessed_image = preprocess_image(image)
|
98 |
+
logging.info(f"Preprocessed image shape: {preprocessed_image.shape}")
|
99 |
+
|
100 |
+
snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(128, 128, 3), target_image=preprocessed_image)
|
101 |
+
generated_image = snn.train(epochs=epochs)
|
102 |
+
return (generated_image * 255).astype(np.uint8)
|
103 |
+
except Exception as e:
|
104 |
+
logging.error(f"Error in train_snn: {str(e)}")
|
105 |
+
return None
|
106 |
|
107 |
interface = gr.Interface(
|
108 |
fn=train_snn,
|
109 |
inputs=[
|
110 |
+
gr.Image(type="numpy", label="Upload Target Image"),
|
111 |
+
gr.Slider(minimum=100, maximum=1000, value=500, step=50, label="Number of Agents"),
|
112 |
+
gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Number of Epochs")
|
113 |
],
|
114 |
outputs=gr.Image(type="numpy", label="Generated Image"),
|
115 |
title="HDC Swarm Neural Network Image Generation"
|