Merged
Changes from 1 commit
36 commits
e3bea27
fine-tuning text classification in progress
santiatpml Jan 31, 2024
c4cf332
More commit messages
santiatpml Feb 1, 2024
fb7cc2a
Working text classification with dataset args and training args
santiatpml Feb 6, 2024
5584487
finetuning with text dataset enum to handle different tasks
santiatpml Feb 7, 2024
82cb4f7
text pair classification task support
santiatpml Feb 7, 2024
c10de47
saving model after training
santiatpml Feb 7, 2024
63ee09b
removed device to cpu
santiatpml Feb 7, 2024
865ae28
updated transformers
Feb 8, 2024
097a8cf
Working e2e finetuning for two tasks
Feb 8, 2024
2dd50e6
Integration with huggingface hub and wandb
Feb 9, 2024
6ac8722
Conversation dataset + training placeholder
Feb 13, 2024
1e40cd8
Updated rust to fix failing tests
Feb 13, 2024
312d893
working version of conversation with lora + load 8bit + hf hub
Feb 13, 2024
afc2e93
Tested llama2-7b finetuning
Feb 22, 2024
22ee5c7
pypgrx first working version
Feb 27, 2024
97d455d
refactoring finetuning code to add callbacks
santiatpml Feb 27, 2024
b700944
fixed merge conflicts
santiatpml Mar 5, 2024
65d2f8b
Refactored finetuning + conversation + pgml callbacks
Mar 2, 2024
5f1b5f4
removed wandb dependency
Mar 4, 2024
08084bf
removed local pypgrx from requirements
Mar 4, 2024
dc0c6ee
removed maturin from requirements
Mar 4, 2024
421af8f
removed flash attn
Mar 4, 2024
4bbca96
Added indent for info display
Mar 5, 2024
3db857c
Updated readme with LLM fine-tuning for text classification
santiatpml Mar 7, 2024
7cbee43
README updates
santiatpml Mar 7, 2024
9284cf1
Added a tutorial for 9 classes - draft 1
santiatpml Mar 8, 2024
66c65c8
README updates
santiatpml Mar 8, 2024
5759ee3
Moved python functions (#1374)
SilasMarvin Mar 18, 2024
b539168
README updates
santiatpml Mar 19, 2024
31215b8
migrations and removed pypgrx
santiatpml Mar 20, 2024
dae6b74
Added r_log to take log level and message
santiatpml Mar 20, 2024
dae5ffc
Updated version and requirements
Mar 22, 2024
435f5bd
Changed version 2.8.3
Mar 22, 2024
aeb2683
README updates for conversation task fine-tuning using lora
santiatpml Mar 22, 2024
e5221cc
minor readme updates
santiatpml Mar 26, 2024
6db147e
added new line
santiatpml Mar 26, 2024
Working e2e finetuning for two tasks
Santi Adavani authored and santiatpml committed Mar 5, 2024
commit 097a8cf00116d2e1f14dd0f8046024d1c4c0b33c
28 changes: 15 additions & 13 deletions pgml-extension/src/api.rs
@@ -906,25 +906,27 @@ fn tune(
LIMIT 1;",
vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())],
);

let mut deploy = true;
match automatic_deploy {
// Deploy only if metrics are better than previous model.
Some(true) | None => {
if let Ok(Some(deployed_metrics)) = deployed_metrics {
let deployed_metrics = deployed_metrics.0.as_object().unwrap();
if project.task.value_is_better(
deployed_metrics
.get(&project.task.default_target_metric())
.unwrap()
.as_f64()
.unwrap(),
new_metrics
.get(&project.task.default_target_metric())
.unwrap()
.as_f64()
.unwrap(),
) {

let deployed_value = deployed_metrics
.get(&project.task.default_target_metric())
.and_then(|value| value.as_f64())
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails

// Get the value for the default target metric from new_metrics or provide a default value
let new_value = new_metrics
.get(&project.task.default_target_metric())
.and_then(|value| value.as_f64())
.unwrap_or_default(); // Default to 0.0 if the key is not present or conversion fails


if project.task.value_is_better(deployed_value, new_value) {
deploy = false;
}
}
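The hunk above replaces nested `unwrap()` calls with `.and_then(|value| value.as_f64()).unwrap_or_default()`, so a missing or non-numeric target metric falls back to 0.0 instead of panicking. A minimal Python sketch of that deploy-or-not decision, included only as an illustration (it is not part of this commit, and `higher_is_better` stands in for whatever `Task::value_is_better` actually checks):

```python
# Hypothetical sketch of the deployment decision in the api.rs hunk above.
# Missing or non-numeric metrics fall back to 0.0, mirroring unwrap_or_default().
def should_deploy(deployed_metrics: dict, new_metrics: dict,
                  target_metric: str, higher_is_better: bool = True) -> bool:
    def metric(metrics: dict) -> float:
        value = metrics.get(target_metric)
        return float(value) if isinstance(value, (int, float)) else 0.0

    deployed_value, new_value = metric(deployed_metrics), metric(new_metrics)
    # Deploy unless the currently deployed model scores strictly better.
    if higher_is_better:
        return not (deployed_value > new_value)
    return not (deployed_value < new_value)

# Example: should_deploy({"f1": 0.82}, {"f1": 0.88}, "f1") -> True
```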
61 changes: 19 additions & 42 deletions pgml-extension/src/bindings/transformers/transformers.py
@@ -1052,26 +1052,14 @@ def tokenize_function(example):
# Generate tokens
train_tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
test_tokenized_datasets = test_dataset.map(tokenize_function, batched=True)
log.info("Tokenization done")
log.info("Train dataset")
log.info(train_tokenized_datasets[0:2])
log.info("Test dataset")
log.info(test_tokenized_datasets[0:2])

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training Args
log.info("Training args setup started path=%s"%path)
training_args=TrainingArguments(output_dir=path, logging_dir=path, **hyperparams["training_args"])

log.info("Trainer setup done")
# Trainer
log.info(model)
log.info(training_args)
log.info(train_tokenized_datasets)
log.info(test_tokenized_datasets)
log.info(tokenizer)
log.info(data_collator)
try:
trainer = Trainer(
model=model,
@@ -1083,15 +1071,16 @@ def tokenize_function(example):
)
except Exception as e:
log.error(e)
log.info("Training started")


# Train
trainer.train()

# Save model
trainer.save_model()

metrics = {"loss" : 0.0}
# TODO: compute real metrics
metrics = {"loss" : 0.0, "f1": 1.0}

return metrics

def finetune_text_pair_classification(task, hyperparams, path, text1_train, text1_test, text2_train, text2_test, class_train, class_test):
@@ -1147,42 +1136,30 @@ def tokenize_function(example):
# Generate tokens
train_tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
test_tokenized_datasets = test_dataset.map(tokenize_function, batched=True)
log.info("Tokenization done")
log.info("Train dataset")
log.info(train_tokenized_datasets[0:2])
log.info("Test dataset")
log.info(test_tokenized_datasets[0:2])

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training Args
log.info("Training args setup started path=%s"%path)
training_args=TrainingArguments(output_dir=path, logging_dir=path, **hyperparams["training_args"])
log.info("Trainer setup done")

# Trainer
log.info(model)
log.info(training_args)
log.info(train_tokenized_datasets)
log.info(test_tokenized_datasets)
log.info(tokenizer)
log.info(data_collator)
try:
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_tokenized_datasets,
eval_dataset=test_tokenized_datasets,
tokenizer=tokenizer,
data_collator=data_collator,
)
except Exception as e:
log.error(e)
log.info("Training started")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_tokenized_datasets,
eval_dataset=test_tokenized_datasets,
tokenizer=tokenizer,
data_collator=data_collator,
)

# Train
trainer.train()

# Save model
trainer.save_model()
metrics = {"loss" : 0.0}

# TODO: Get real metrics
metrics = {"loss" : 0.0, "f1": 1.0}

return metrics
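Both functions above still return hard-coded metrics (`{"loss": 0.0, "f1": 1.0}`) behind a TODO. One plausible way to fill that in, sketched here as an assumption rather than as part of this commit, is a `compute_metrics` callback passed to the `Trainer`, after which `trainer.evaluate()` reports the evaluation loss and F1:

```python
# Hypothetical sketch for the "TODO: compute real metrics" placeholders above.
import numpy as np
from sklearn.metrics import f1_score

def compute_metrics(eval_pred):
    # The Trainer passes (logits, labels) for the evaluation split.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"f1": f1_score(labels, predictions, average="macro")}

# trainer = Trainer(..., compute_metrics=compute_metrics)  # added to the existing call
# eval_results = trainer.evaluate()  # e.g. {"eval_loss": ..., "eval_f1": ...}
# metrics = {"loss": eval_results["eval_loss"], "f1": eval_results["eval_f1"]}
```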
48 changes: 24 additions & 24 deletions pgml-extension/src/orm/model.rs
@@ -242,12 +242,6 @@ impl Model {

model.metrics = Some(JsonB(json!(metrics)));
info!("Metrics: {:?}", &metrics);
// let metrics = match transformers::finetune(&project.task, dataset, &model.hyperparams, &path) {
// Ok(metrics) => metrics,
// Err(e) => error!("{e}"),
// };
// model.metrics = Some(JsonB(json!(metrics)));
// info!("Metrics: {:?}", &metrics);

Spi::get_one_with_args::<i64>(
"UPDATE pgml.models SET hyperparams = $1, metrics = $2 WHERE id = $3 RETURNING id",
@@ -266,26 +260,31 @@ impl Model {
.unwrap();

// Save the bindings.
/*for entry in std::fs::read_dir(&path).unwrap() {
for entry in std::fs::read_dir(&path).unwrap() {
let path = entry.unwrap().path();
let bytes = std::fs::read(&path).unwrap();
for (i, chunk) in bytes.chunks(100_000_000).enumerate() {
Spi::get_one_with_args::<i64>(
"INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, $2, $3, $4) RETURNING id",
vec![
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
(
PgBuiltInOids::TEXTOID.oid(),
path.file_name().unwrap().to_str().into_datum(),
),
(PgBuiltInOids::INT8OID.oid(), (i as i64).into_datum()),
(PgBuiltInOids::BYTEAOID.oid(), chunk.into_datum()),
],
)
.unwrap();
}
}*/

if path.is_file() {

let bytes = std::fs::read(&path).unwrap();

for (i, chunk) in bytes.chunks(100_000_000).enumerate() {
Spi::get_one_with_args::<i64>(
"INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, $2, $3, $4) RETURNING id",
vec![
(PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
(
PgBuiltInOids::TEXTOID.oid(),
path.file_name().unwrap().to_str().into_datum(),
),
(PgBuiltInOids::INT8OID.oid(), (i as i64).into_datum()),
(PgBuiltInOids::BYTEAOID.oid(), chunk.into_datum()),
],
)
.unwrap();
}
}
}

Spi::run_with_args(
"UPDATE pgml.models SET status = $1::pgml.status WHERE id = $2",
Some(vec![
@@ -297,6 +296,7 @@ impl Model {
]),
)
.unwrap();

model
}

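The loop above streams every file written by `trainer.save_model()` into `pgml.files` in 100 MB parts. For completeness, a hypothetical sketch (not part of this commit) of reading one of those files back by concatenating its parts in order, assuming a plain psycopg2 connection to the same database; the file name and connection string are placeholders:

```python
# Hypothetical sketch: reassemble a saved model file from pgml.files,
# using the (model_id, path, part, data) columns shown in the INSERT above.
import psycopg2

def read_model_file(conn, model_id: int, file_name: str) -> bytes:
    with conn.cursor() as cur:
        cur.execute(
            "SELECT data FROM pgml.files "
            "WHERE model_id = %s AND path = %s ORDER BY part",
            (model_id, file_name),
        )
        return b"".join(bytes(part) for (part,) in cur.fetchall())

# conn = psycopg2.connect("dbname=pgml_development")  # placeholder connection
# blob = read_model_file(conn, 42, "model.safetensors")  # placeholder id/name
```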