TuringsSolutions committed
Commit 7d43cc6
1 Parent(s): 1f219e4

Create app.py

Files changed (1)
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
import gradio as gr
import numpy as np
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from keras.models import Model
import matplotlib.pyplot as plt
import logging
from skimage.transform import resize
from PIL import Image
from tqdm import tqdm

class SwarmAgent:
    def __init__(self, position, velocity):
        self.position = position
        # velocity and the moment buffers m/v are initialized here but never
        # read by the denoising update in update_agents().
        self.velocity = velocity
        self.m = np.zeros_like(position)
        self.v = np.zeros_like(position)

class SwarmNeuralNetwork:
    def __init__(self, num_agents, image_shape, target_image):
        self.image_shape = image_shape
        self.resized_shape = (64, 64, 3)
        self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
        self.target_image = self.load_target_image(target_image)
        self.generated_image = np.random.randn(*image_shape)  # Start with noise
        self.mobilenet = self.load_mobilenet_model()
        self.current_epoch = 0
        self.noise_schedule = np.linspace(0.1, 0.002, 1000)  # Noise schedule

    def random_position(self):
        return np.random.randn(*self.image_shape)  # Use Gaussian noise

    def random_velocity(self):
        return np.random.randn(*self.image_shape) * 0.01

    def load_target_image(self, img):
        if img is None:
            # No target supplied (e.g. when a saved model is loaded for generation only).
            return np.zeros(self.image_shape)
        img = img.convert('RGB').resize((self.image_shape[1], self.image_shape[0]))
        img_array = np.array(img) / 127.5 - 1  # Normalize to [-1, 1]
        plt.imshow((img_array + 1) / 2)  # Convert back to [0, 1] for display
        plt.title('Target Image')
        plt.show()
        return img_array

    def resize_image(self, image):
        return resize(image, self.resized_shape, anti_aliasing=True)

    def load_mobilenet_model(self):
        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)

    def add_positional_encoding(self, image):
        # Currently unused: not called anywhere in this file.
        h, w, c = image.shape
        pos_enc = np.zeros_like(image)
        for i in range(h):
            for j in range(w):
                pos_enc[i, j, :] = [i/h, j/w, 0]
        return image + pos_enc

    def multi_head_attention(self, agent, num_heads=4):
        # Currently unused: not called anywhere in this file.
        attention_scores = []
        for _ in range(num_heads):
            similarity = np.exp(-np.sum((agent.position - self.target_image)**2, axis=-1))
            attention_score = similarity / np.sum(similarity)
            attention_scores.append(attention_score)
        attention = np.mean(attention_scores, axis=0)
        return np.expand_dims(attention, axis=-1)

    def multi_scale_perceptual_loss(self, agent_positions):
        # Currently unused: not called by the training loop below.
        target_image_resized = self.resize_image((self.target_image + 1) / 2)  # Convert to [0, 1] for MobileNet
        target_image_preprocessed = preprocess_input(target_image_resized[np.newaxis, ...] * 255)  # MobileNet expects [0, 255]
        target_features = self.mobilenet.predict(target_image_preprocessed)

        losses = []
        for agent_position in agent_positions:
            agent_image_resized = self.resize_image((agent_position + 1) / 2)
            agent_image_preprocessed = preprocess_input(agent_image_resized[np.newaxis, ...] * 255)
            agent_features = self.mobilenet.predict(agent_image_preprocessed)

            loss = np.mean((target_features - agent_features)**2)
            losses.append(1 / (1 + loss))

        return np.array(losses)

    def update_agents(self, timestep):
        noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]

        for agent in self.agents:
            # Predict noise
            predicted_noise = agent.position - self.target_image

            # Denoise
            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)

            # Add scaled noise for next step
            agent.position = denoised + np.random.randn(*self.image_shape) * np.sqrt(noise_level)

            # Clip values
            agent.position = np.clip(agent.position, -1, 1)

    def generate_image(self):
        self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
        # Normalize to [0, 1] range for display
        self.generated_image = (self.generated_image + 1) / 2
        self.generated_image = np.clip(self.generated_image, 0, 1)

    def train(self, epochs):
        logging.basicConfig(filename='training.log', level=logging.INFO)

        for epoch in tqdm(range(epochs), desc="Training Epochs"):
            self.update_agents(epoch)
            self.generate_image()

            mse = np.mean(((self.generated_image * 2 - 1) - self.target_image)**2)
            logging.info(f"Epoch {epoch}, MSE: {mse}")

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, MSE: {mse}")
                self.display_image(self.generated_image, title=f'Epoch {epoch}')
            self.current_epoch += 1

    def display_image(self, image, title=''):
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()

    def display_agent_positions(self, epoch):
        fig, ax = plt.subplots()
        positions = np.array([agent.position for agent in self.agents])
        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
        plt.title(f'Agent Positions at Epoch {epoch}')
        plt.show()

    def save_model(self, filename):
        model_state = {
            'agents': self.agents,
            'generated_image': self.generated_image,
            'current_epoch': self.current_epoch
        }
        np.save(filename, model_state)

    def load_model(self, filename):
        model_state = np.load(filename, allow_pickle=True).item()
        self.agents = model_state['agents']
        self.generated_image = model_state['generated_image']
        self.current_epoch = model_state['current_epoch']

    def generate_new_image(self, num_steps=1000):
        # Note: update_agents() denoises toward self.target_image, which is not
        # restored by load_model(), so results depend on how this instance was constructed.
        for agent in self.agents:
            agent.position = np.random.randn(*self.image_shape)

        for step in tqdm(range(num_steps), desc="Generating Image"):
            self.update_agents(num_steps - step - 1)  # Reverse order

        self.generate_image()
        return self.generated_image

# Gradio Interface
def train_snn(image, num_agents, epochs):
    # Slider values can arrive as floats, so cast before using them as counts.
    snn = SwarmNeuralNetwork(num_agents=int(num_agents), image_shape=(64, 64, 3), target_image=image)
    snn.train(epochs=int(epochs))
    snn.save_model('snn_model.npy')
    # Return a uint8 array, which Gradio displays reliably.
    return (snn.generated_image * 255).astype(np.uint8)

def generate_new_image():
    # Not wired into the interface below; kept for loading a saved model and sampling.
    snn = SwarmNeuralNetwork(num_agents=2000, image_shape=(64, 64, 3), target_image=None)
    snn.load_model('snn_model.npy')
    new_image = snn.generate_new_image()
    return new_image

interface = gr.Interface(
    fn=train_snn,
    inputs=[
        gr.Image(type="pil", label="Upload Target Image"),
        gr.Slider(minimum=500, maximum=3000, value=2000, step=1, label="Number of Agents"),
        gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Number of Epochs")
    ],
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="Swarm Neural Network Image Generation",
    description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image."
)

interface.launch()
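
For reference, a minimal usage sketch (not part of the commit) of driving the class directly, without the Gradio UI. It assumes SwarmNeuralNetwork is available in the current Python session and that a local RGB image named target.jpg exists; the agent count, epoch count, and filenames are illustrative only.

    from PIL import Image

    target = Image.open("target.jpg").convert("RGB")
    snn = SwarmNeuralNetwork(num_agents=500, image_shape=(64, 64, 3), target_image=target)
    snn.train(epochs=50)                                   # logs per-epoch MSE to training.log
    snn.save_model("snn_model.npy")                        # pickled NumPy state
    regenerated = snn.generate_new_image(num_steps=200)    # re-sample starting from noise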