ljt019 commited on
Commit
fb141c8
1 Parent(s): 31f46ed

dev: update model card

Browse files
Files changed (1) hide show
  1. README.md +76 -25
README.md CHANGED
@@ -1,25 +1,76 @@
1
- # Cat and Dog Sketch Classifier
2
-
3
- This project contains a machine learning model that differentiates between sketches of cats and dogs. It was created as a learning exercise to understand how AI models work and how to train them.
4
-
5
- ## Project Structure
6
-
7
- - `quickdraw_data/` - Directory containing `cat.npy` and `dog.npy` files.
8
- - `cat_dog_classifier.bin` - Trained model file.
9
- - `config.json` - Configuration file with parameters for the model and training.
10
- - `model.py` - Contains the model definition for the Convolutional Neural Network (CNN).
11
- - `sample_predictions.png` - Image file with sample predictions from the model.
12
- - `train_cat_dog_classifier.py` - Script to train the classifier.
13
- - `requirements.txt` - Python Dependencies needed to run the training script
14
-
15
- ## Dependencies
16
-
17
- To install the required Python packages to train the model, use the following command:
18
-
19
- ```bash
20
- pip install -r requirements.txt
21
- ```
22
-
23
- ## License
24
-
25
- MIT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Cat and Dog Sketch Classifier
3
+ emoji: 🐱🐶
4
+ tags:
5
+ - image-classification
6
+ - quickdraw
7
+ - cat
8
+ - dog
9
+ license: mit
10
+ ---
11
+
12
+ # Cat and Dog Sketch Classifier
13
+
14
+ This is a machine learning model trained to differentiate between sketches of cats and dogs. It was built as part of a learning project to understand how AI models work and how to train them.
15
+
16
+ ## Model Details
17
+
18
+ - **Model Type**: Convolutional Neural Network (CNN)
19
+ - **Training Data**: Quick, Draw! dataset (cat and dog sketches)
20
+ - **License**: MIT License
21
+ - **Supported Tasks**: Image Classification
22
+
23
+ ## Usage
24
+
25
+ To use this model, you can follow these steps:
26
+
27
+ 1. **Load the Model**:
28
+ ```python
29
+ import torch
30
+ from model import SimpleCNN
31
+ model = SimpleCNN()
32
+ model.load_state_dict(torch.load('cat_dog_classifier.bin'))
33
+ model.eval()
34
+ ```
35
+
36
+ 2. **Predict an Image**:
37
+ ```python
38
+ from PIL import Image
39
+ import numpy as np
40
+ import torch
41
+
42
+ def predict_image(model, image):
43
+ # Preprocess the image
44
+ if isinstance(image, Image.Image):
45
+ image = image.resize((28, 28)).convert('L')
46
+ image = np.array(image).astype('float32') / 255.0
47
+ elif isinstance(image, np.ndarray):
48
+ if image.shape != (28, 28):
49
+ image = Image.fromarray(image).resize((28, 28)).convert('L')
50
+ image = np.array(image).astype('float32') / 255.0
51
+ else:
52
+ raise ValueError("Image must be a PIL Image or NumPy array.")
53
+
54
+ image = image.reshape(1, 1, 28, 28)
55
+ image_tensor = torch.tensor(image).to(device)
56
+
57
+ # Get prediction
58
+ model.eval()
59
+ with torch.no_grad():
60
+ output = model(image_tensor)
61
+ _, predicted = torch.max(output.data, 1)
62
+ return 'cat' if predicted.item() == 0 else 'dog'
63
+
64
+ # Example usage
65
+ image = Image.open('path/to/your/image.png')
66
+ prediction = predict_image(model, image)
67
+ print(prediction)
68
+ ```
69
+
70
+ ## Training the Model
71
+
72
+ To train the model yourself, use the provided `train_cat_dog_classifier.py` script.
73
+
74
+ ## License
75
+
76
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.