Merged
Changes from 1 commit
Commits
22 commits
3ff2f07
Added OpenSourceAI and conversational support in the extension
SilasMarvin Nov 22, 2023
1ca5dc8
Clean up errors and guard rails around conversational api
SilasMarvin Nov 22, 2023
b6b7ec6
Not great, pivoting to better solution after talking with Santi
SilasMarvin Nov 28, 2023
be4788a
Working conversational everything
SilasMarvin Nov 28, 2023
719fdc5
Fixed typo
SilasMarvin Nov 28, 2023
23ef2a3
Working non streaming open source ai replacement
SilasMarvin Nov 28, 2023
dedf434
Remove outdated comment
SilasMarvin Nov 28, 2023
f3d8e1f
Working OpenSourceAI with both sync and async options
SilasMarvin Nov 29, 2023
2969c89
Cleaned up and tested well
SilasMarvin Nov 29, 2023
3b26743
Completely removed the GPTQ pipeline as it is no longer necessary
SilasMarvin Nov 29, 2023
cf1afc6
Removed unnecessary python imports
SilasMarvin Nov 29, 2023
accd159
Removed universal debugger output
SilasMarvin Nov 30, 2023
e5eccec
Finalized models in SDK for open source ai
SilasMarvin Dec 1, 2023
95e1e9a
Updated to work with hugging face tokens
SilasMarvin Dec 1, 2023
c80817b
Finalized models in SDK for open source ai
SilasMarvin Dec 1, 2023
9a3ca91
Removed unnecessary comment
SilasMarvin Dec 1, 2023
4eb88f8
Put back the GGML pipeline and removed the GPTQ pipeline earlier comm…
SilasMarvin Dec 1, 2023
93e7ffb
Changed some error messages
SilasMarvin Dec 1, 2023
73ee33a
Added migration for 2.8.1
SilasMarvin Dec 1, 2023
0201880
Working migration file
SilasMarvin Dec 1, 2023
47c18d6
Really working migration file
SilasMarvin Dec 1, 2023
fb3f7f7
Bumped version
SilasMarvin Dec 1, 2023
Not great, pivoting to better solution after talking with Santi
SilasMarvin committed Nov 29, 2023
commit b6b7ec600bc86f7e7a38c4fadf35423f06263b7e
33 changes: 15 additions & 18 deletions pgml-extension/src/api.rs
@@ -682,17 +682,14 @@ pub fn transform_conversational_string(
pub fn transform_stream_json(
task: JsonB,
args: default!(JsonB, "'{}'"),
input: default!(&str, "''"),
inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
cache: default!(bool, false),
) -> SetOfIterator<'static, String> {
) -> SetOfIterator<'static, JsonB> {
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
let python_iter = crate::bindings::transformers::transform_stream_iterator(
&task.0,
&args.0,
input.to_string(),
)
.map_err(|e| error!("{e}"))
.unwrap();
let python_iter =
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
.map_err(|e| error!("{e}"))
.unwrap();
SetOfIterator::new(python_iter)
}

@@ -702,13 +699,13 @@ pub fn transform_stream_json(
pub fn transform_stream_string(
task: String,
args: default!(JsonB, "'{}'"),
input: default!(&str, "''"),
inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
cache: default!(bool, false),
) -> SetOfIterator<'static, String> {
) -> SetOfIterator<'static, JsonB> {
let task_json = json!({ "task": task });
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
let python_iter =
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input)
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
.map_err(|e| error!("{e}"))
.unwrap();
SetOfIterator::new(python_iter)
@@ -720,9 +717,9 @@ pub fn transform_stream_string(
pub fn transform_stream_conversational_json(
task: JsonB,
args: default!(JsonB, "'{}'"),
input: default!(JsonB, "'[]'::JSONB"),
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
cache: default!(bool, false),
) -> SetOfIterator<'static, String> {
) -> SetOfIterator<'static, JsonB> {
if !task.0["task"]
.as_str()
.is_some_and(|v| v == "conversational")
@@ -733,7 +730,7 @@ pub fn transform_stream_conversational_json(
}
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
let python_iter =
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input.0)
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
.map_err(|e| error!("{e}"))
.unwrap();
SetOfIterator::new(python_iter)
@@ -745,9 +742,9 @@ pub fn transform_stream_conversational_json(
pub fn transform_stream_conversational_string(
task: String,
args: default!(JsonB, "'{}'"),
input: default!(JsonB, "'[]'::JSONB"),
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
cache: default!(bool, false),
) -> SetOfIterator<'static, String> {
) -> SetOfIterator<'static, JsonB> {
if task != "conversational" {
error!(
"JSONB inputs for transformer_stream should only be used with a conversational task"
@@ -756,7 +753,7 @@ pub fn transform_stream_conversational_string(
let task_json = json!({ "task": task });
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
let python_iter =
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input.0)
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
.map_err(|e| error!("{e}"))
.unwrap();
SetOfIterator::new(python_iter)
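For context on what these new signatures accept, here is a minimal, purely illustrative sketch of driving the updated streaming API from Python. It is not part of this commit and it rests on a few assumptions: that the pgrx functions above are exposed under the SQL name `pgml.transform_stream` (as their Rust names suggest), that the `psycopg` (v3) driver is installed, and that a PostgresML database is reachable at the placeholder DSN. The two visible changes are that `inputs` is now an array rather than a single string, and each streamed row is a JSONB array with one text chunk per batched input.

```python
# Hedged sketch, not part of this commit. Assumptions: the functions above are
# exposed as pgml.transform_stream, psycopg (v3) is installed, and a PostgresML
# server is reachable at the placeholder DSN below.
import psycopg

DSN = "postgresql://postgres@localhost:5432/postgres"  # placeholder connection string

with psycopg.connect(DSN) as conn, conn.cursor() as cur:
    cur.execute(
        """
        SELECT pgml.transform_stream(
            '{"task": "text-generation", "model": "gpt2"}'::jsonb,  -- task
            '{"max_new_tokens": 32}'::jsonb,                        -- args
            ARRAY['Once upon a time']::text[]                       -- inputs (now an array)
        )
        """
    )
    for (chunk,) in cur:
        # psycopg decodes each JSONB row into a Python list: one partial
        # text chunk per batched input, e.g. [' there'].
        print(chunk[0], end="", flush=True)
```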
9 changes: 4 additions & 5 deletions pgml-extension/src/bindings/transformers/transform.rs
@@ -4,7 +4,6 @@ use anyhow::Result;
use pgrx::*;
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::AsPyPointer;

create_pymodule!("/src/bindings/transformers/transformers.py");

@@ -24,17 +23,17 @@ impl TransformStreamIterator {
}

impl Iterator for TransformStreamIterator {
type Item = String;
type Item = JsonB;
fn next(&mut self) -> Option<Self::Item> {
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
Python::with_gil(|py| -> Result<Option<JsonB>, PyErr> {
let code = "next(python_iter)";
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
if res.is_none() {
Ok(None)
} else {
let res: String = res.extract()?;
Ok(Some(res))
let res: Vec<String> = res.extract()?;
Ok(Some(JsonB(serde_json::to_value(res).unwrap())))
}
})
.map_err(|e| error!("{e}"))
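The iterator on the Rust side now expects the Python object returned by `transform_stream_iterator` to behave as follows: each `next()` call yields a Python list of decoded text chunks (one per batched input), and `None` once generation is done, which is why the `extract` above pulls a `Vec<String>` and wraps it in `JsonB`. A tiny stand-in that satisfies that contract (illustrative only, with made-up data) looks like:

```python
# Illustrative stand-in for the object handed to TransformStreamIterator.
# Contract assumed from the Rust code above: next() returns a list of strings
# (one chunk per batched input) while streaming, and None once the stream ends,
# rather than raising StopIteration.
class FakeStreamIterator:
    def __init__(self, batches):
        self._batches = iter(batches)

    def __iter__(self):
        return self

    def __next__(self):
        # Fall back to None instead of raising, matching the res.is_none() check.
        return next(self._batches, None)


it = FakeStreamIterator([["Once", "Once upon"], [" a", " a time"]])
assert next(it) == ["Once", "Once upon"]  # one chunk per input in the batch
assert next(it) == [" a", " a time"]
assert next(it) is None                   # end of stream
```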
124 changes: 84 additions & 40 deletions pgml-extension/src/bindings/transformers/transformers.py
@@ -41,7 +41,7 @@
TrainingArguments,
Trainer,
TextStreamer,
Conversation
Conversation,
)
from threading import Thread
from typing import Optional
@@ -95,24 +95,34 @@ def ensure_device(kwargs):
else:
kwargs["device"] = "cpu"

# A copy of HuggingFace's TextIteratorStreamer with small changes in __next__ so it does not raise an exception
class TextIteratorStreamer(TextStreamer):
def __init__(
self, tokenizer, skip_prompt = False, timeout = None, **decode_kwargs
):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.text_queue = queue.Queue()
self.stop_signal = None

# Follows BaseStreamer template from transformers library
class TextIteratorStreamer:
def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.timeout = timeout
self.decode_kwargs = decode_kwargs
self.next_tokens_are_prompt = True
self.stop_signal = None
self.text_queue = queue.Queue()

def on_finalized_text(self, text: str, stream_end: bool = False):
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
self.text_queue.put(text, timeout=self.timeout)
if stream_end:
self.text_queue.put(self.stop_signal, timeout=self.timeout)
def put(self, value):
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
# Can't batch this decode
decoded_values = []
for v in value:
decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
self.text_queue.put(decoded_values, self.timeout)

def end(self):
self.next_tokens_are_prompt = True
self.text_queue.put(self.stop_signal, self.timeout)

def __iter__(self):
return self
self

def __next__(self):
value = self.text_queue.get(timeout=self.timeout)
@@ -215,6 +225,18 @@ def __init__(self, model_name, **kwargs):
# to the model constructor, so we construct the model/tokenizer manually if possible,
# but that is only possible when the task is passed in, since if you pass the model
# to the pipeline constructor, the task will no longer be inferred from the default...

# We want to create a text-generation pipeline if it is a conversational task
self.conversational = False
if "task" in kwargs and kwargs["task"] == "conversational":
self.conversational = True
kwargs["task"] = "text-generation"

# Tokens can either be left or right padded depending on the architecture
padding_side = "right"
if "padding_side" in kwargs:
padding_side = kwargs.pop("padding_side")

if (
"task" in kwargs
and model_name is not None
@@ -224,8 +246,7 @@ def __init__(self, model_name, **kwargs):
"question-answering",
"summarization",
"translation",
"text-generation",
"conversational"
"text-generation"
]
):
self.task = kwargs.pop("task")
@@ -240,56 +261,75 @@ def __init__(self, model_name, **kwargs):
)
elif self.task == "summarization" or self.task == "translation":
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
elif self.task == "text-generation" or self.task == "conversational":
elif self.task == "text-generation":
self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
else:
raise PgMLException(f"Unhandled task: {self.task}")

if "use_auth_token" in kwargs:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, use_auth_token=kwargs["use_auth_token"]
model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side
)
else:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)

self.pipe = transformers.pipeline(
self.task,
model=self.model,
tokenizer=self.tokenizer,
)
else:
self.pipe = transformers.pipeline(**kwargs)
self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side)
self.tokenizer = self.pipe.tokenizer
self.task = self.pipe.task
self.model = self.pipe.model
if self.pipe.tokenizer is None:
self.pipe.tokenizer = AutoTokenizer.from_pretrained(
self.model.name_or_path
)
self.tokenizer = self.pipe.tokenizer

# Make sure we set the pad token if it does not exist
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

def stream(self, inputs, **kwargs):
streamer = None
generation_kwargs = None
if self.task == "conversational":
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
inputs = tokenized_chat = self.tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device)
generation_kwargs = dict(inputs=inputs, streamer=streamer, **kwargs)
# Conversational does not work right now with left padded tokenizers. At least for gpt2, the apply_chat_template breaks it
if self.conversational:
streamer = TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
templated_inputs = []
for input in inputs:
templated_inputs.append(
self.tokenizer.apply_chat_template(
input, add_generation_prompt=True, tokenize=False
)
)
inputs = self.tokenizer(
templated_inputs, return_tensors="pt", padding=True
).to(self.model.device)
generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
else:
streamer = TextIteratorStreamer(self.tokenizer)
inputs = self.tokenizer([inputs], return_tensors="pt").to(self.model.device)
streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(
self.model.device
)
generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
print("\n\n", file=sys.stderr)
print(inputs, file=sys.stderr)
print("\n\n", file=sys.stderr)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
return streamer

def __call__(self, inputs, **kwargs):
if self.task == "conversational":
outputs = []
for conversation in inputs:
conversation = Conversation(conversation)
conversation = self.pipe(conversation, **kwargs)
outputs.append(conversation.generated_responses[-1])
return outputs
if self.conversational:
templated_inputs = []
for input in inputs:
templated_inputs.append(
self.tokenizer.apply_chat_template(
input, add_generation_prompt=True, tokenize=False
)
)
return self.pipe(templated_inputs, return_full_text=False, **kwargs)
else:
return self.pipe(inputs, **kwargs)

@@ -320,7 +360,11 @@ def create_pipeline(task):
lower = None
if lower and ("-ggml" in lower or "-gguf" in lower):
pipe = GGMLPipeline(model_name, **task)
elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"):
elif (
lower
and "-gptq" in lower
and not (model_type == "mistral" or model_type == "llama")
):
pipe = GPTQPipeline(model_name, **task)
else:
try:
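To make the new input shapes concrete, the sketch below shows how the rewritten `StandardPipeline` appears to be driven after this commit: for a conversational task, both `__call__` and `stream` take a batch of chat histories (lists of role/content messages) and apply the tokenizer's chat template; for other tasks they take a batch of plain strings. The model name and generation arguments are placeholders, and the sketch assumes it runs somewhere this module's `StandardPipeline` and its transformers dependencies are importable.

```python
# Hedged usage sketch, not part of this commit. Assumes StandardPipeline from this
# module is importable and that the placeholder model below is available locally.
pipe = StandardPipeline(
    "HuggingFaceH4/zephyr-7b-beta",  # placeholder chat-capable model
    task="conversational",           # remapped internally to text-generation
)

conversations = [
    [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "What is PostgresML?"},
    ],
]

# Non-streaming path: the chat template is applied per conversation and the
# underlying text-generation pipeline returns only the newly generated text.
outputs = pipe(conversations, max_new_tokens=64)

# Streaming path: each iteration yields a list of decoded chunks (one per
# conversation in the batch); the streamer puts None on the queue when done.
for chunks in pipe.stream(conversations, max_new_tokens=64):
    if chunks is None:
        break
    print(chunks[0], end="", flush=True)
```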