K00B404 commited on
Commit
91e19b1
1 Parent(s): 48978a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -27
app.py CHANGED
@@ -60,29 +60,94 @@ class UNetWrapper:
60
  def __init__(self, unet_model, repo_id):
61
  self.model = unet_model
62
  self.repo_id = repo_id
63
- self.token = os.getenv('HF_WRITE')
 
64
 
65
  def push_to_hub(self):
66
- # Initialize the Hugging Face API
67
- api = HfApi()
68
- # Create a repository if it doesn't exist
69
- create_repo(self.repo_id, exist_ok=True,token=self.token)
70
- pth_name = 'model_weights.pth'
71
- torch.save(self.model.state_dict(), pth_name)
72
- from huggingface_hub import upload_file
73
-
74
- upload_file(
75
- path_or_fileobj=pth_name,
76
- path_in_repo=pth_name,
77
- repo_id=self.repo_id,
78
- token=self.token,
79
- repo_type="model"
80
-
81
- )
82
- #api.upload_file(repo_id=self.repo_id, path_in_repo=pth_name, path_or_fileobj=pth_name)
83
- # Push the model's state dict to the Hugging Face Hub
84
- #self.model.save_pretrained(self.repo_id,token=self.token) # You may need to implement this method
85
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # Training function
88
  def train_model(epochs):
@@ -134,12 +199,9 @@ def train_model(epochs):
134
  return model
135
 
136
  # Push model to Hugging Face Hub
137
- def push_model_to_hub(model, repo_name):
138
- # Usage example
139
- model_wrapper = UNetWrapper(model, model_repo_id)
140
- model_wrapper.push_to_hub()
141
- # Push the model to the Hugging Face hub
142
- #model.push_to_hub(repo_name)
143
 
144
  # Gradio interface function
145
  def gradio_train(epochs):
 
60
  def __init__(self, unet_model, repo_id):
61
  self.model = unet_model
62
  self.repo_id = repo_id
63
+ self.token = os.getenv('HF_WRITE') # Make sure this environment variable is set
64
+ self.api = HfApi()
65
 
66
  def push_to_hub(self):
67
+ try:
68
+ # Save model state and configuration
69
+ save_dict = {
70
+ 'model_state_dict': self.model.state_dict(),
71
+ 'model_config': {
72
+ 'big': isinstance(self.model, big_UNet),
73
+ 'img_size': 1024 if isinstance(self.model, big_UNet) else 256
74
+ },
75
+ 'model_architecture': str(self.model)
76
+ }
77
+
78
+ # Save model locally
79
+ pth_name = 'model_weights.pth'
80
+ torch.save(save_dict, pth_name)
81
+
82
+ # Create repo if it doesn't exist
83
+ try:
84
+ create_repo(
85
+ repo_id=self.repo_id,
86
+ token=self.token,
87
+ exist_ok=True
88
+ )
89
+ except Exception as e:
90
+ print(f"Repository creation note: {e}")
91
+
92
+ # Upload the model file
93
+ self.api.upload_file(
94
+ path_or_fileobj=pth_name,
95
+ path_in_repo=pth_name,
96
+ repo_id=self.repo_id,
97
+ token=self.token,
98
+ repo_type="model"
99
+ )
100
+
101
+ # Create and upload model card
102
+ model_card = f"""---
103
+ tags:
104
+ - unet
105
+ - pix2pix
106
+ library_name: pytorch
107
+ ---
108
+
109
+ # Pix2Pix UNet Model
110
+
111
+ ## Model Description
112
+ Custom UNet model for Pix2Pix image translation.
113
+ - Image Size: {1024 if isinstance(self.model, big_UNet) else 256}
114
+ - Model Type: {"Big (1024)" if isinstance(self.model, big_UNet) else "Small (256)"}
115
+
116
+ ## Usage
117
+
118
+ ```python
119
+ import torch
120
+ from small_256_model import UNet as small_UNet
121
+ from big_1024_model import UNet as big_UNet
122
+
123
+ # Load the model
124
+ checkpoint = torch.load('model_weights.pth')
125
+ model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
126
+ model.load_state_dict(checkpoint['model_state_dict'])
127
+ model.eval()
128
+ Model Architecture
129
+ {str(self.model)}
130
+ """
131
+ # Save and upload README
132
+ with open("README.md", "w") as f:
133
+ f.write(model_card)
134
+
135
+ self.api.upload_file(
136
+ path_or_fileobj="README.md",
137
+ path_in_repo="README.md",
138
+ repo_id=self.repo_id,
139
+ token=self.token,
140
+ repo_type="model"
141
+ )
142
+
143
+ # Clean up local files
144
+ os.remove(pth_name)
145
+ os.remove("README.md")
146
+
147
+ print(f"Model successfully uploaded to {self.repo_id}")
148
+
149
+ except Exception as e:
150
+ print(f"Error uploading model: {e}")
151
 
152
  # Training function
153
  def train_model(epochs):
 
199
  return model
200
 
201
  # Push model to Hugging Face Hub
202
+ def push_model_to_hub(model, repo_id):
203
+ wrapper = UNetWrapper(model, repo_id)
204
+ wrapper.push_to_hub()
 
 
 
205
 
206
  # Gradio interface function
207
  def gradio_train(epochs):