🌐 AI搜索 & 代理 主页
Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e3bea27
fine-tuning text classification in progress
santiatpml Jan 31, 2024
c4cf332
More commit messages
santiatpml Feb 1, 2024
fb7cc2a
Working text classification with dataset args and training args
santiatpml Feb 6, 2024
5584487
finetuing with text dataset enum to handle different tasks
santiatpml Feb 7, 2024
82cb4f7
text pair classification task support
santiatpml Feb 7, 2024
c10de47
saving model after training
santiatpml Feb 7, 2024
63ee09b
removed device to cpu
santiatpml Feb 7, 2024
865ae28
updated transforemrs
Feb 8, 2024
097a8cf
Working e2e finetunig for two tasks
Feb 8, 2024
2dd50e6
Integration with huggingface hub and wandb
Feb 9, 2024
6ac8722
Conversation dataset + training placeholder
Feb 13, 2024
1e40cd8
Updated rust to fix failing tests
Feb 13, 2024
312d893
working version of conversation with lora + load 8bit + hf hub
Feb 13, 2024
afc2e93
Tested llama2-7b finetuning
Feb 22, 2024
22ee5c7
pypgrx first working version
Feb 27, 2024
97d455d
refactoring finetuning code to add callbacks
santiatpml Feb 27, 2024
b700944
fixed merge conflicts
santiatpml Mar 5, 2024
65d2f8b
Refactored finetuning + conversation + pgml callbacks
Mar 2, 2024
5f1b5f4
removed wandb dependency
Mar 4, 2024
08084bf
removed local pypgrx from requirements
Mar 4, 2024
dc0c6ee
removed maturin from requirements
Mar 4, 2024
421af8f
removed flash attn
Mar 4, 2024
4bbca96
Added indent for info display
Mar 5, 2024
3db857c
Updated readme with LLM fine-tuning for text classification
santiatpml Mar 7, 2024
7cbee43
README updates
santiatpml Mar 7, 2024
9284cf1
Added a tutorial for 9 classes - draft 1
santiatpml Mar 8, 2024
66c65c8
README updates
santiatpml Mar 8, 2024
5759ee3
Moved python functions (#1374)
SilasMarvin Mar 18, 2024
b539168
README updates
santiatpml Mar 19, 2024
31215b8
migrations and removed pypgrx
santiatpml Mar 20, 2024
dae6b74
Added r_log to take log level and message
santiatpml Mar 20, 2024
dae5ffc
Updated version and requirements
Mar 22, 2024
435f5bd
Changed version 2.8.3
Mar 22, 2024
aeb2683
README updates for conversation task fine-tuning using lora
santiatpml Mar 22, 2024
e5221cc
minor readme updates
santiatpml Mar 26, 2024
6db147e
added new line
santiatpml Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
text pair classification task support
  • Loading branch information
santiatpml committed Mar 5, 2024
commit 82cb4f7af4706c4f06a694305ca5a9d83bcd60a3
31 changes: 30 additions & 1 deletion pgml-extension/src/bindings/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use pyo3::types::PyTuple;
use serde_json::Value;

use crate::create_pymodule;
use crate::orm::{Task, TextClassificationDataset};
use crate::orm::{Task, TextClassificationDataset, TextPairClassificationDataset};

use super::TracebackError;

Expand Down Expand Up @@ -106,6 +106,35 @@ pub fn finetune_text_classification(task: &Task, dataset: TextClassificationData
output.extract(py).format_traceback(py)
})
}

pub fn finetune_text_pair_classification(task: &Task, dataset: TextPairClassificationDataset, hyperparams: &JsonB, path: &Path) -> Result<HashMap<String, f64>> {
let task = task.to_string();
let hyperparams = serde_json::to_string(&hyperparams.0)?;

Python::with_gil(|py| -> Result<HashMap<String, f64>> {
let tune = get_module!(PY_MODULE).getattr(py, "finetune_text_pair_classification").format_traceback(py)?;
let path = path.to_string_lossy();
let output = tune
.call1(
py,
(
&task,
&hyperparams,
path.as_ref(),
dataset.text1_train,
dataset.text1_test,
dataset.text2_train,
dataset.text2_test,
dataset.class_train,
dataset.class_test,
),
)
.format_traceback(py)?;

output.extract(py).format_traceback(py)
})
}

pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<String>> {
Python::with_gil(|py| -> Result<Vec<String>> {
let generate = get_module!(PY_MODULE).getattr(py, "generate").format_traceback(py)?;
Expand Down
90 changes: 90 additions & 0 deletions pgml-extension/src/bindings/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,96 @@ def tokenize_function(example):
log.error(e)
log.info("Training started")

# Train
trainer.train()
metrics = {"loss" : 0.0}
return metrics

def finetune_text_pair_classification(task, hyperparams, path, text1_train, text1_test, text2_train, text2_test, class_train, class_test):
# Get model and tokenizer
hyperparams = orjson.loads(hyperparams)
model_name = hyperparams.pop("model_name")
tokenizer = AutoTokenizer.from_pretrained(model_name)
classes = list(set(class_train))
num_classes = len(classes)

id2label = {}
label2id = {}
for id, label in enumerate(classes):
label2id[label] = id
id2label[id] = label

model = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_classes, id2label=id2label, label2id=label2id
)

model.config.id2label = id2label
model.config.label2id = label2id

y_train_label = [label2id[_class] for _class in class_train]
y_test_label = [label2id[_class] for _class in class_test]

# Prepare dataset
train_dataset = datasets.Dataset.from_dict(
{
"text1": text1_train,
"text2" : text2_train,
"label": y_train_label,
}
)
test_dataset = datasets.Dataset.from_dict(
{
"text1": text1_test,
"text2": text2_test,
"label": y_test_label,
}
)
# tokenization function
def tokenize_function(example):
tokenized_example = tokenizer(
example["text1"],
example["text2"],
padding=True,
truncation=True,
return_tensors="pt"
)
return tokenized_example

# Generate tokens
train_tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
test_tokenized_datasets = test_dataset.map(tokenize_function, batched=True)
log.info("Tokenization done")
log.info("Train dataset")
log.info(train_tokenized_datasets[0:2])
log.info("Test dataset")
log.info(test_tokenized_datasets[0:2])
# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training Args
log.info("Training args setup started path=%s"%path)
training_args=TrainingArguments(output_dir="/tmp/postgresml/models/", logging_dir="/tmp/postgresml/runs", **hyperparams["training_args"])
log.info("Trainer setup done")
# Trainer
log.info(model)
log.info(training_args)
log.info(train_tokenized_datasets)
log.info(test_tokenized_datasets)
log.info(tokenizer)
log.info(data_collator)
try:
trainer = Trainer(
model=model.to("cpu"),
args=training_args,
train_dataset=train_tokenized_datasets,
eval_dataset=test_tokenized_datasets,
tokenizer=tokenizer,
data_collator=data_collator,
)
except Exception as e:
log.error(e)
log.info("Training started")

# Train
trainer.train()
metrics = {"loss" : 0.0}
Expand Down
46 changes: 37 additions & 9 deletions pgml-extension/src/orm/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ impl Dataset {
}
}

pub enum TextDatasetType {
TextClassification(TextClassificationDataset),
TextPairClassification(TextPairClassificationDataset),
}

impl TextDatasetType {
pub fn num_features(&self) -> usize {
match self {
TextDatasetType::TextClassification(dataset) => dataset.num_features,
TextDatasetType::TextPairClassification(dataset) => dataset.num_features,
}
}
}

// TextClassificationDataset
pub struct TextClassificationDataset {
pub text_train: Vec<String>,
Expand All @@ -86,24 +100,38 @@ impl Display for TextClassificationDataset {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(
f,
"TextClassificationDataset {{ num_features: {}, num_labels: {}, num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
self.num_features, self.num_labels, self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows,
"TextClassificationDataset {{ num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows,
)
}
}

pub enum TextDatasetType {
TextClassification(TextClassificationDataset),
pub struct TextPairClassificationDataset {
pub text1_train: Vec<String>,
pub text2_train: Vec<String>,
pub class_train: Vec<String>,
pub text1_test: Vec<String>,
pub text2_test: Vec<String>,
pub class_test: Vec<String>,
pub num_features: usize,
pub num_labels: usize,
pub num_rows: usize,
pub num_train_rows: usize,
pub num_test_rows: usize,
pub num_distinct_labels: usize,
}

impl TextDatasetType {
pub fn num_features(&self) -> usize {
match self {
TextDatasetType::TextClassification(dataset) => dataset.num_features,
}
impl Display for TextPairClassificationDataset {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(
f,
"TextPairClassificationDataset {{ num_distinct_labels: {}, num_rows: {}, num_train_rows: {}, num_test_rows: {} }}",
self.num_distinct_labels, self.num_rows, self.num_train_rows, self.num_test_rows,
)
}
}


fn drop_table_if_exists(table_name: &str) {
// Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first
let table_count = Spi::get_one_with_args::<i64>(
Expand Down
1 change: 1 addition & 0 deletions pgml-extension/src/orm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub use algorithm::Algorithm;
pub use dataset::Dataset;
pub use dataset::TextDatasetType;
pub use dataset::TextClassificationDataset;
pub use dataset::TextPairClassificationDataset;
pub use model::Model;
pub use project::Project;
pub use runtime::Runtime;
Expand Down
11 changes: 10 additions & 1 deletion pgml-extension/src/orm/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ impl Model {
// let dataset = snapshot.text_classification_dataset(dataset_args);
let dataset = if project.task == Task::text_classification {
TextDatasetType::TextClassification(snapshot.text_classification_dataset(dataset_args))
} else if project.task == Task::text_pair_classification {
TextDatasetType::TextPairClassification(snapshot.text_pair_classification_dataset(dataset_args))
} else {
TextDatasetType::TextClassification(snapshot.text_classification_dataset(dataset_args))
panic!("Unsupported task for finetuning")
};

// Create the model record.
Expand Down Expand Up @@ -229,6 +231,13 @@ impl Model {
};

}
TextDatasetType::TextPairClassification(dataset) => {
metrics = match transformers::finetune_text_pair_classification(&project.task, dataset, &model.hyperparams, &path) {
Ok(metrics) => metrics,
Err(e) => error!("{e}"),
};

}
};

model.metrics = Some(JsonB(json!(metrics)));
Expand Down
Loading