TirthGPT commited on
Commit
802ec79
1 Parent(s): 6b16dc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -45
app.py CHANGED
@@ -9,6 +9,8 @@ from torchvision.models import resnet50
9
 
10
  # Initialize inference client for chat
11
  chat_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
12
 
13
  # Load pre-trained image classification model
14
  model = resnet50(pretrained=True)
@@ -21,15 +23,14 @@ transform = transforms.Compose([
21
  ])
22
 
23
  def search_wikipedia(query):
24
- try:
25
- summary = wikipedia.summary(query)
26
- return summary
27
- except wikipedia.exceptions.DisambiguationError as e:
28
- return f"Disambiguation error: {e}"
29
- except wikipedia.exceptions.PageError:
30
  return "No information found on that topic."
31
 
32
  def respond(message, history, system_message, max_tokens, temperature, top_p):
 
33
  search_response = search_wikipedia(message)
34
 
35
  # Prepare the chat messages
@@ -62,53 +63,57 @@ def classify_image(image):
62
  _, predicted = torch.max(output, 1)
63
  return f"Predicted class index: {predicted.item()}"
64
 
65
- # Placeholder functions for video generation and classification
66
- def generate_video(video):
67
- return video # Placeholder: Just returns the input video for now
68
-
69
- def classify_video(video):
70
- return "Video classification logic not implemented." # Placeholder
71
-
72
  # Gradio interface setup using Blocks
73
  with gr.Blocks() as demo:
74
  gr.Markdown("## Multi-Functional AI Interface")
75
 
76
- with gr.Row():
77
- with gr.Column():
78
- system_message = gr.Textbox(value="You are a friendly Chatbot named Tirth.", label="System message")
79
- max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
80
- temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
81
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
 
 
 
 
 
 
82
 
83
- with gr.Column():
84
- chat_output = gr.Chatbot(label="Chat History")
85
- user_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")
86
- submit_btn = gr.Button("Send")
87
 
88
- # Image Classification
89
- image_input = gr.Image(type="pil", label="Upload an Image for Classification")
90
- classify_btn = gr.Button("Classify Image")
91
- classification_output = gr.Textbox(label="Classification Result")
92
-
93
- # Video Generation
94
- video_input = gr.Video(label="Upload a Video for Generation")
95
- generate_video_btn = gr.Button("Generate Video")
96
- video_output = gr.Video(label="Generated Video")
97
-
98
- # Video Classification
99
- video_class_input = gr.Video(label="Upload a Video for Classification")
100
- classify_video_btn = gr.Button("Classify Video")
101
- video_classification_output = gr.Textbox(label="Video Classification Result")
 
 
 
 
 
102
 
103
- def on_submit(message, history):
104
- response, search_response = respond(message, history, system_message.value, max_tokens.value, temperature.value, top_p.value)
105
- return history + [(message, response)], search_response
 
 
 
 
 
106
 
107
- # Connect the button to the submission function
108
- submit_btn.click(on_submit, inputs=[user_input, chat_output], outputs=[chat_output, gr.Textbox(label="Wikipedia Summary")])
109
- classify_btn.click(classify_image, inputs=image_input, outputs=classification_output)
110
- generate_video_btn.click(generate_video, inputs=video_input, outputs=video_output)
111
- classify_video_btn.click(classify_video, inputs=video_class_input, outputs=video_classification_output)
112
 
113
  if __name__ == "__main__":
114
  demo.launch()
 
9
 
10
  # Initialize inference client for chat
11
  chat_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
12
+ # Initialize Wikipedia API
13
+ wiki_wiki = wikipediaapi.Wikipedia('en')
14
 
15
  # Load pre-trained image classification model
16
  model = resnet50(pretrained=True)
 
23
  ])
24
 
25
  def search_wikipedia(query):
26
+ page = wiki_wiki.page(query)
27
+ if page.exists():
28
+ return page.summary
29
+ else:
 
 
30
  return "No information found on that topic."
31
 
32
  def respond(message, history, system_message, max_tokens, temperature, top_p):
33
+ # Search Wikipedia for information
34
  search_response = search_wikipedia(message)
35
 
36
  # Prepare the chat messages
 
63
  _, predicted = torch.max(output, 1)
64
  return f"Predicted class index: {predicted.item()}"
65
 
 
 
 
 
 
 
 
66
  # Gradio interface setup using Blocks
67
  with gr.Blocks() as demo:
68
  gr.Markdown("## Multi-Functional AI Interface")
69
 
70
+ with gr.Tab("Chatbot with Wikipedia Search"):
71
+ with gr.Row():
72
+ with gr.Column():
73
+ system_message = gr.Textbox(value="You are a friendly Chatbot named Tirth.", label="System message")
74
+ max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
75
+ temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
76
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
77
+
78
+ with gr.Column():
79
+ chat_output = gr.Chatbot(label="Chat History")
80
+ user_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")
81
+ submit_btn = gr.Button("Send")
82
 
83
+ def on_submit(message, history):
84
+ response, search_response = respond(message, history, system_message.value, max_tokens.value, temperature.value, top_p.value)
85
+ return history + [(message, response)], search_response
 
86
 
87
+ submit_btn.click(on_submit, inputs=[user_input, chat_output], outputs=[chat_output, gr.Textbox(label="Wikipedia Summary")])
88
+
89
+ with gr.Tab("Image Classification"):
90
+ image_input = gr.Image(type="pil", label="Upload an Image")
91
+ classify_btn = gr.Button("Classify Image")
92
+ classification_output = gr.Textbox(label="Classification Result")
93
+
94
+ classify_btn.click(classify_image, inputs=image_input, outputs=classification_output)
95
+
96
+ with gr.Tab("Video Generation"):
97
+ video_input = gr.Video(label="Upload a Video")
98
+ generate_video_btn = gr.Button("Generate Video")
99
+ video_output = gr.Video(label="Generated Video")
100
+
101
+ # Placeholder for video generation logic (implement as needed)
102
+ def generate_video(video):
103
+ return video # Just returns the input video for now
104
+
105
+ generate_video_btn.click(generate_video, inputs=video_input, outputs=video_output)
106
 
107
+ with gr.Tab("Video Classification"):
108
+ video_class_input = gr.Video(label="Upload a Video for Classification")
109
+ classify_video_btn = gr.Button("Classify Video")
110
+ video_classification_output = gr.Textbox(label="Video Classification Result")
111
+
112
+ # Placeholder for video classification logic (implement as needed)
113
+ def classify_video(video):
114
+ return "Video classification logic not implemented." # Placeholder
115
 
116
+ classify_video_btn.click(classify_video, inputs=video_class_input, outputs=video_classification_output)
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
  demo.launch()