"""
Build txtai workflows.

Based on this example: https://github.com/neuml/txtai/blob/master/examples/workflows.py
"""

import os
import re

import nltk
import pandas as pd
import streamlit as st

from txtai.embeddings import Documents, Embeddings
from txtai.pipeline import Segmentation, Summary, Tabular, Translation
from txtai.workflow import ServiceTask, Task, UrlTask, Workflow


class Application:
    """
    Streamlit application.
    """

    def __init__(self):
        """
        Creates a new Streamlit application.
        """
        self.components = {}
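
        # Defined pipelines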
        self.pipelines = {}
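
        # Current workflow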
        self.workflow = []
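
        # Embeddings index parameters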
        self.embeddings = None
        self.documents = None
        self.data = None

    def number(self, label):
        """
        Extracts a number from a text input field.

        Args:
            label: label to use for text input field

        Returns:
            numeric input
        """

        value = st.sidebar.text_input(label)
        return int(value) if value else None

    def split(self, text):
        """
        Splits text on commas and returns a list.

        Args:
            text: input text

        Returns:
            list
        """

        return [x.strip() for x in text.split(",")]

    def options(self, component):
        """
        Extracts component settings into a component configuration dict.

        Args:
            component: component type

        Returns:
            dict with component settings
        """

        options = {"type": component}

        st.sidebar.markdown("---")

        if component == "embeddings":
            st.sidebar.markdown("**Embeddings Index** \n*Index workflow output*")
            options["path"] = st.sidebar.text_input("Embeddings model path", value="sentence-transformers/nli-mpnet-base-v2")
            options["upsert"] = st.sidebar.checkbox("Upsert")

        elif component == "summary":
            st.sidebar.markdown("**Summary** \n*Abstractive text summarization*")
            options["path"] = st.sidebar.text_input("Model", value="sshleifer/distilbart-cnn-12-6")
            options["minlength"] = self.number("Min length")
            options["maxlength"] = self.number("Max length")

        elif component == "segment":
            st.sidebar.markdown("**Segment** \n*Split text into semantic units*")
            options["sentences"] = st.sidebar.checkbox("Split sentences")
            options["lines"] = st.sidebar.checkbox("Split lines")
            options["paragraphs"] = st.sidebar.checkbox("Split paragraphs")
            options["join"] = st.sidebar.checkbox("Join tokenized")
            options["minlength"] = self.number("Min section length")

        elif component == "service":
            # Settings header, matching the style of the other components
            st.sidebar.markdown("**Service** \n*Extract data from an API*")
            options["url"] = st.sidebar.text_input("URL")
            options["method"] = st.sidebar.selectbox("Method", ["get", "post"], index=0)
            options["params"] = st.sidebar.text_input("URL parameters")
            options["batch"] = st.sidebar.checkbox("Run as batch", value=True)
            options["extract"] = st.sidebar.text_input("Subsection(s) to extract")

            if options["params"]:
                options["params"] = {key: None for key in self.split(options["params"])}
            if options["extract"]:
                options["extract"] = self.split(options["extract"])

        elif component == "tabular":
            # Settings header, matching the style of the other components
            st.sidebar.markdown("**Tabular** \n*Split tabular data into rows and columns*")
            options["idcolumn"] = st.sidebar.text_input("Id column")
            options["textcolumns"] = st.sidebar.text_input("Text columns")
            if options["textcolumns"]:
                options["textcolumns"] = self.split(options["textcolumns"])

        elif component == "translate":
            st.sidebar.markdown("**Translate** \n*Machine translation*")
            options["target"] = st.sidebar.text_input("Target language code", value="en")

        return options

    def build(self, components):
        """
        Builds a workflow using components.

        Args:
            components: list of components to add to workflow
        """
        self.__init__()
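
        # Create a workflow task for each component configuration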
        tasks = []
        for component in components:
            component = dict(component)
            wtype = component.pop("type")
            self.components[wtype] = component
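
            # Queue documents during the run, indexing happens after the workflow completes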
            if wtype == "embeddings":
                self.embeddings = Embeddings({**component})
                self.documents = Documents()
                tasks.append(Task(self.documents.add, unpack=False))

            elif wtype == "segment":
                self.pipelines[wtype] = Segmentation(**self.components["segment"])
                tasks.append(Task(self.pipelines["segment"]))

            elif wtype == "service":
                tasks.append(ServiceTask(**self.components["service"]))

            elif wtype == "summary":
                self.pipelines[wtype] = Summary(component.pop("path"))
                tasks.append(Task(lambda x: self.pipelines["summary"](x, **self.components["summary"])))

            elif wtype == "tabular":
                self.pipelines[wtype] = Tabular(**self.components["tabular"])
                tasks.append(Task(self.pipelines["tabular"]))

            elif wtype == "translate":
                self.pipelines[wtype] = Translation()
                tasks.append(Task(lambda x: self.pipelines["translate"](x, **self.components["translate"])))

        self.workflow = Workflow(tasks)

    def find(self, key):
        """
        Looks up a record in cached data by uid key.

        Args:
            key: uid to search for

        Returns:
            text for matching uid
        """

        return [text for uid, text, _ in self.data if uid == key][0]

    def process(self, data):
        """
        Processes the current application action.

        Args:
            data: input data
        """

        if data and self.workflow:
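            # Build (id, data, tags) tuples when documents are queued for indexing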
            if self.documents:
                data = [(x, element, None) for x, element in enumerate(data)]
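
            # Run the workflow, writing results to the page unless they are queued for indexing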
            for result in self.workflow(data):
                if not self.documents:
                    st.write(result)
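
            # Build the embeddings index over queued documents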
            if self.documents:
                # Cache data for query-time lookups
                self.data = list(self.documents)

                with st.spinner("Building embedding index...."):
                    self.embeddings.index(self.documents)
                    self.documents.close()

                # Clear workflow state once indexed
                self.documents, self.pipelines, self.workflow = None, None, None

        if self.embeddings and self.data:
            query = st.text_input("Query")
            limit = min(5, len(self.data))
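
            # Hide the dataframe index column and left-align table output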
            st.markdown(
                """
                <style>
                table td:nth-child(1) {
                    display: none
                }
                table th:nth-child(1) {
                    display: none
                }
                table {text-align: left !important}
                </style>
                """,
                unsafe_allow_html=True,
            )

            if query:
                df = pd.DataFrame([{"content": self.find(uid), "score": score} for uid, score in self.embeddings.search(query, limit)])
                st.table(df)

    def parse(self, data):
        """
        Parses input data, splitting on new lines depending on the type of tasks and the format of the input.

        Args:
            data: input data

        Returns:
            parsed data
        """
        if re.match(r"^(http|https|file)://", data) or (self.workflow and isinstance(self.workflow.tasks[0], ServiceTask)):
            return [x for x in data.split("\n") if x]

        return [data]

    def run(self):
        """
        Runs Streamlit application.
        """

        st.sidebar.image("https://github.com/neuml/txtai/raw/master/logo.png", width=256)
        st.sidebar.markdown("# Workflow builder \n*Build and apply workflows to data* \n[GitHub](https://github.com/neuml/txtai)")
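
        # Available workflow components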
        components = ["embeddings", "segment", "service", "summary", "tabular", "translate"]
        selected = st.sidebar.multiselect("Select components", components)
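
        # Collect settings for each selected component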
        components = [self.options(component) for component in selected]
        st.sidebar.markdown("---")

        with st.sidebar:
            build = st.button("Build", help="Build the workflow and run within this application")
            if build:
                with st.spinner("Building workflow...."):
                    self.build(components)

        with st.expander("Data", expanded=not self.data):
            data = st.text_area("Input", height=10)

        # Parse text items
        data = self.parse(data) if data else data

        # Process current action
        self.process(data)


@st.cache(allow_output_mutation=True)
def create():
    """
    Creates and caches a Streamlit application.

    Returns:
        Application
    """

    return Application()


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Download the NLTK punkt sentence tokenizer if it is not already available
    try:
        nltk.sent_tokenize("This is a test. Split")
    except LookupError:
        nltk.download("punkt")

    # Create and run application
    app = create()
    app.run()