Commit
•
e525bd5
1
Parent(s):
5ff9afc
fix: update example
Browse files
app.py
CHANGED
@@ -50,7 +50,7 @@ def load_examples():
|
|
50 |
|
51 |
|
52 |
# Create Gradio examples
|
53 |
-
examples = load_examples()
|
54 |
|
55 |
|
56 |
def process_fields(fields):
|
@@ -112,34 +112,44 @@ from gradio_client import Client
|
|
112 |
import argilla as rg
|
113 |
|
114 |
# Initialize Argilla client
|
115 |
-
|
|
|
116 |
api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
|
117 |
)
|
118 |
|
119 |
# Load the dataset
|
120 |
-
dataset =
|
121 |
-
|
122 |
-
#
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
|
|
126 |
payload = {
|
127 |
-
"records": [
|
128 |
-
"fields": [
|
129 |
-
"question":
|
|
|
|
|
130 |
}
|
131 |
|
132 |
-
|
133 |
-
client = Client("davidberenstein1957/distilabel-argilla-labeller")
|
134 |
-
|
135 |
-
result = client.predict(
|
136 |
-
records=json.dumps(payload["records"]),
|
137 |
-
example_records=json.dumps(payload["example_records"]),
|
138 |
-
fields=json.dumps(payload["fields"]),
|
139 |
-
question=json.dumps(payload["question"]),
|
140 |
-
api_name="/predict"
|
141 |
-
)
|
142 |
-
|
143 |
```
|
144 |
"""
|
145 |
|
|
|
50 |
|
51 |
|
52 |
# Create Gradio examples
|
53 |
+
examples = load_examples()[:1]
|
54 |
|
55 |
|
56 |
def process_fields(fields):
|
|
|
112 |
import argilla as rg
|
113 |
|
114 |
# Initialize Argilla client
|
115 |
+
gradio_client = Client("davidberenstein1957/distilabel-argilla-labeller")
|
116 |
+
argilla_client = rg.Argilla(
|
117 |
api_key=os.environ["ARGILLA_API_KEY"], api_url=os.environ["ARGILLA_API_URL"]
|
118 |
)
|
119 |
|
120 |
# Load the dataset
|
121 |
+
dataset = argilla_client.datasets(name="my_dataset", workspace="my_workspace")
|
122 |
+
|
123 |
+
# Get the field and question
|
124 |
+
field = dataset.settings.fields["text"]
|
125 |
+
question = dataset.settings.questions["sentiment"]
|
126 |
+
|
127 |
+
# Get completed and pending records
|
128 |
+
completed_records_filter = rg.Filter(("status", "==", "completed"))
|
129 |
+
pending_records_filter = rg.Filter(("status", "==", "pending"))
|
130 |
+
example_records = list(
|
131 |
+
dataset.records(
|
132 |
+
query=rg.Query(filter=completed_records_filter),
|
133 |
+
limit=5,
|
134 |
+
)
|
135 |
+
)
|
136 |
+
some_pending_records = list(
|
137 |
+
dataset.records(
|
138 |
+
query=rg.Query(filter=pending_records_filter),
|
139 |
+
limit=5,
|
140 |
+
)
|
141 |
+
)
|
142 |
|
143 |
+
# Process the records
|
144 |
payload = {
|
145 |
+
"records": [record.to_dict() for record in some_pending_records],
|
146 |
+
"fields": [field.serialize()],
|
147 |
+
"question": question.serialize(),
|
148 |
+
"example_records": [record.to_dict() for record in example_records],
|
149 |
+
"api_name": "/predict",
|
150 |
}
|
151 |
|
152 |
+
response = gradio_client.predict(**payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
```
|
154 |
"""
|
155 |
|