Jiahuita committed on
Commit
1d3834b
1 Parent(s): ade1685

Add model files and API implementation

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  combined_data.csv filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  combined_data.csv filter=lfs diff=lfs merge=lfs -text
37
+ *.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,74 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ tags:
4
+ - text-classification
5
+ - tensorflow
6
+ - news-classification
7
+ pipeline_tag: text-classification
8
+ widget:
9
+ - text: "Enter your news headline here"
10
+ datasets:
11
+ - custom_news_dataset
12
+ model-index:
13
+ - name: news-source-classifier
14
+ results:
15
+ - task:
16
+ type: text-classification
17
+ name: News Source Classification
18
+ metrics:
19
+ - type: accuracy
20
+ value: 0.82
21
+ name: Test Accuracy
22
  ---
23
+
24
+ # News Source Classifier
25
+
26
+ This model predicts whether a news headline was published by Fox News or NBC News, using an LSTM neural network.
27
+
28
+ ## Model Description
29
+
30
+ - **Model Architecture**: LSTM Neural Network
31
+ - **Input**: News headlines (text)
32
+ - **Output**: Binary classification (Fox News vs NBC)
33
+ - **Training Data**: Large collection of headlines from both news sources
34
+ - **Performance**: Achieves approximately 82% accuracy on the test set
35
+
36
+ ## Usage
37
+
38
+ You can use this model directly with a FastAPI endpoint:
39
+
40
+ ```python
41
+ import requests
42
+
43
+ # Make a prediction
44
+ response = requests.post(
45
+ "https://your-app-url/predict",
46
+ json={"text": "Your news headline here"}
47
+ )
48
+ print(response.json())
49
+ ```
50
+
51
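+ Or use it locally (this assumes the model has been published in a `transformers`-compatible format; the files in this repository are a Keras `.h5` model and a Keras tokenizer JSON, which `pipeline` is not expected to load as-is):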
52
+
53
+ ```python
54
+ from transformers import pipeline
55
+
56
+ classifier = pipeline("text-classification", model="your-username/news-source-classifier")
57
+ result = classifier("Your news headline here")
58
+ print(result)
59
+ ```
60
+
61
+ ## Limitations and Bias
62
+
63
+ This model has been trained on news headlines from specific sources and time periods, which may introduce certain biases. Users should be aware of these limitations when using the model.
64
+
65
+ ## Training
66
+
67
+ The model was trained using:
68
+ - TensorFlow 2.13.0
69
+ - LSTM architecture
70
+ - Binary cross-entropy loss
71
+ - Adam optimizer
72
+
73
+ ## License
74
+ This project is licensed under the MIT License.
__pycache__/app.cpython-39.pyc ADDED
Binary file (2.23 kB). View file
 
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from tensorflow.keras.models import load_model
4
+ from tensorflow.keras.preprocessing.text import tokenizer_from_json
5
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
6
+ import json
7
+ import pickle
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+
11
# ASGI application object; the route decorators further down register on it.
app = FastAPI(title="News Source Classifier")

# Load every inference artifact once at import time. Paths are relative to
# the process working directory. Any failure is reported and re-raised so the
# service refuses to start instead of serving without a model.
try:
    model = load_model('news_classifier.h5')

    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding when reading it.
    with open('tokenizer.json', encoding='utf-8') as f:
        tokenizer_data = json.load(f)
    # tokenizer_from_json() parses a JSON *string*. Depending on how the file
    # was written, json.load() yields either that string (double-encoded
    # file) or the already-decoded object; normalise the latter back to a
    # string so both layouts load.
    if not isinstance(tokenizer_data, str):
        tokenizer_data = json.dumps(tokenizer_data)
    tokenizer = tokenizer_from_json(tokenizer_data)

    # SECURITY: pickle.load() can execute arbitrary code embedded in the
    # file. Only ever load a vectorizer.pkl produced by a trusted pipeline.
    # NOTE(review): nothing in the visible request path reads `vectorizer`;
    # confirm whether this artifact is still needed.
    with open('vectorizer.pkl', 'rb') as f:
        vectorizer = pickle.load(f)
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise
25
+
26
# Request body for POST /predict.
# (Plain comments are used instead of a docstring so the schema that
# pydantic generates for this model is left untouched.)
class PredictionRequest(BaseModel):
    # Raw text to classify; it is passed to the tokenizer unchanged.
    text: str
28
+
29
# Response body returned by POST /predict.
# (Plain comments are used instead of a docstring so the schema that
# pydantic generates for this model is left untouched.)
class PredictionResponse(BaseModel):
    # Predicted outlet label; the endpoint emits 'foxnews' or 'nbc'.
    source: str
    # Model score backing the predicted label, as a float.
    confidence: float
32
+
33
def _decode_prediction(prediction):
    """Convert the raw model output for one sample into ``(source, confidence)``.

    Two output layouts are handled:

    * several scores (e.g. a softmax head): the highest-scoring index is the
      predicted class and its score is the confidence;
    * a single score (e.g. one sigmoid unit trained with binary
      cross-entropy): ``argmax`` over one element is always 0, so the score
      is thresholded at 0.5 instead and the confidence is the probability of
      whichever class was chosen.

    Class 0 maps to ``'foxnews'`` and any other class to ``'nbc'``.
    NOTE(review): for the single-score layout this assumes the score is the
    probability of class 1 ('nbc') -- confirm against the training labels.
    """
    scores = np.asarray(prediction, dtype=float).reshape(-1)

    if scores.size == 1:
        positive = float(scores[0])
        predicted_class = int(positive >= 0.5)
        confidence = positive if predicted_class == 1 else 1.0 - positive
    else:
        predicted_class = int(np.argmax(scores))
        confidence = float(scores[predicted_class])

    source = 'foxnews' if predicted_class == 0 else 'nbc'
    return source, confidence


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """Classify the submitted text as coming from Fox News or NBC.

    The text is tokenized with the module-level ``tokenizer``, padded to a
    fixed length of 100 tokens and scored by the module-level ``model``.

    Returns:
        PredictionResponse carrying the predicted ``source`` label and the
        ``confidence`` of that label.

    Raises:
        HTTPException: status 500 with the underlying error message if
            tokenization, padding or inference fails.
    """
    try:
        sequence = tokenizer.texts_to_sequences([request.text])
        padded = pad_sequences(sequence, maxlen=100)

        # NOTE(review): model.predict() is a blocking call inside an async
        # handler, so it holds the event loop while it runs.
        prediction = model.predict(padded)
        source, confidence = _decode_prediction(prediction)

        return PredictionResponse(
            source=source,
            confidence=confidence,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
51
+
52
# Landing endpoint: returns a short JSON description of the service and of
# how to call the prediction route.
@app.get("/")
async def root():
    usage_hint = "Make a POST request to /predict with a JSON payload containing 'text' field"
    service_info = {
        "message": "News Source Classifier API",
        "usage": usage_hint,
    }
    return service_info
news_classifier.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9258ee4d92199555974374b569634e73ad0d2b059d3b7125f3b75c2144528f4
3
+ size 117315152
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ tensorflow==2.13.0
2
+ fastapi==0.68.1
3
+ uvicorn==0.15.0
4
+ numpy>=1.19.2
5
+ pydantic==1.8.2
6
+ python-multipart==0.0.5
7
+ scikit-learn>=0.24.2
8
+ joblib>=1.1.0
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f470ba41cfc5aead8a050a1fdb44b35191b4c72de7703db579a668867363cf24
3
+ size 7021963
vectorizer.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c379aaf596d44439b4330c4f89e3813e734dea7867b4b5fd9065c547161e552
3
+ size 900222