TirthGPT committed
Commit 69312ef
Parent: fb42732

Update app.py

Files changed (1): app.py (+52 -5)
app.py CHANGED
@@ -1,12 +1,27 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 import wikipediaapi
+from PIL import Image
+import requests
+import torch
+from torchvision import transforms
+from torchvision.models import resnet50
 
 # Initialize inference client for chat
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+chat_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 # Initialize Wikipedia API
 wiki_wiki = wikipediaapi.Wikipedia('en')
 
+# Load pre-trained image classification model
+model = resnet50(pretrained=True)
+model.eval()
+transform = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
 def search_wikipedia(query):
     page = wiki_wiki.page(query)
     if page.exists():
@@ -30,7 +45,7 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
 
     # Generate response from chat model
     response = ""
-    for message in client.chat_completion(
+    for message in chat_client.chat_completion(
         messages,
         max_tokens=max_tokens,
         stream=True,
@@ -41,28 +56,60 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
         response += token
         yield response, search_response  # Return both responses
 
+def classify_image(image):
+    image = transform(image).unsqueeze(0)
+    with torch.no_grad():
+        output = model(image)
+        _, predicted = torch.max(output, 1)
+    return f"Predicted class index: {predicted.item()}"
+
+# Placeholder functions for video generation and classification
+def generate_video(video):
+    return video  # Placeholder: Just returns the input video for now
+
+def classify_video(video):
+    return "Video classification logic not implemented."  # Placeholder
+
 # Gradio interface setup using Blocks
 with gr.Blocks() as demo:
-    gr.Markdown("## Chatbot with Wikipedia Search")
+    gr.Markdown("## Multi-Functional AI Interface")
+
     with gr.Row():
         with gr.Column():
             system_message = gr.Textbox(value="You are a friendly Chatbot named Tirth.", label="System message")
             max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
             temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
-
+
         with gr.Column():
             chat_output = gr.Chatbox(label="Chat History")
            user_input = gr.Textbox(placeholder="Type your message here...", label="Your Message")
             submit_btn = gr.Button("Send")
 
-    # Function to handle button click
+    # Image Classification
+    image_input = gr.Image(type="pil", label="Upload an Image for Classification")
+    classify_btn = gr.Button("Classify Image")
+    classification_output = gr.Textbox(label="Classification Result")
+
+    # Video Generation
+    video_input = gr.Video(label="Upload a Video for Generation")
+    generate_video_btn = gr.Button("Generate Video")
+    video_output = gr.Video(label="Generated Video")
+
+    # Video Classification
+    video_class_input = gr.Video(label="Upload a Video for Classification")
+    classify_video_btn = gr.Button("Classify Video")
+    video_classification_output = gr.Textbox(label="Video Classification Result")
+
    def on_submit(message, history):
         response, search_response = respond(message, history, system_message.value, max_tokens.value, temperature.value, top_p.value)
         return history + [(message, response)], search_response
 
     # Connect the button to the submission function
     submit_btn.click(on_submit, inputs=[user_input, chat_output], outputs=[chat_output, gr.Textbox(label="Wikipedia Summary")])
+    classify_btn.click(classify_image, inputs=image_input, outputs=classification_output)
+    generate_video_btn.click(generate_video, inputs=video_input, outputs=video_output)
+    classify_video_btn.click(classify_video, inputs=video_class_input, outputs=video_classification_output)
 
 if __name__ == "__main__":
     demo.launch()
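
Note on the new classify_image helper: as committed, it returns only the raw ImageNet class index. A minimal sketch (not part of this commit) of how the index could be mapped to a human-readable label, assuming torchvision >= 0.13, where the weight enums expose matching preprocessing and category names:

import torch
from torchvision.models import resnet50, ResNet50_Weights

# Hypothetical variant of classify_image that returns a label instead of an index.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
transform = weights.transforms()  # preprocessing matched to these weights

def classify_image(image):
    batch = transform(image).unsqueeze(0)       # add a batch dimension
    with torch.no_grad():
        output = model(batch)
    predicted = output.argmax(dim=1).item()     # highest-scoring class index
    return f"Predicted class: {weights.meta['categories'][predicted]}"

Wired into the same Gradio button, this would display a class name such as "golden retriever" in the Classification Result textbox rather than a bare integer.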