🌐 AI搜索 & 代理 主页
Skip to content

Commit 600bbfe

Browse files
committed
Updated load_dataset to be resistant to bad columns
1 parent 41f36f0 commit 600bbfe

File tree

1 file changed

+30
-19
lines changed
  • pgml-extension/src/bindings/transformers

1 file changed

+30
-19
lines changed

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -434,54 +434,65 @@ pub fn load_dataset(
434434
Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?;
435435
let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#);
436436
for i in 0..num_rows {
437+
let mut skip = false;
437438
let mut row = Vec::with_capacity(num_cols);
438439
for (name, values) in data {
439440
let value = values
440441
.as_array()
441442
.ok_or_else(|| anyhow!("expected {values} to be an array"))?
442443
.get(i)
443444
.ok_or_else(|| anyhow!("invalid index {i} for {values}"))?;
444-
match types
445+
let (ty, datum) = match types
445446
.get(name)
446447
.ok_or_else(|| anyhow!("{types:?} expected to have key {name}"))?
447448
.as_str()
448449
.ok_or_else(|| anyhow!("json field {name} expected to be string"))?
449450
{
450-
"string" => row.push((
451+
"string" => (
451452
PgBuiltInOids::TEXTOID.oid(),
452453
value
453454
.as_str()
454-
.ok_or_else(|| anyhow!("expected {value} to be string"))?
455-
.into_datum(),
456-
)),
457-
"dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())),
458-
"int64" | "int32" | "int16" => row.push((
455+
.map(IntoDatum::into_datum)
456+
.ok_or_else(|| anyhow!("expected column {name} with {value} to be string")),
457+
),
458+
"dict" | "list" => (PgBuiltInOids::JSONBOID.oid(), Ok(JsonB(value.clone()).into_datum())),
459+
"int64" | "int32" | "int16" => (
459460
PgBuiltInOids::INT8OID.oid(),
460461
value
461462
.as_i64()
462-
.ok_or_else(|| anyhow!("expected {value} to be i64"))?
463-
.into_datum(),
464-
)),
465-
"float64" | "float32" | "float16" => row.push((
463+
.map(IntoDatum::into_datum)
464+
.ok_or_else(|| anyhow!("expected column {name} with {value} to be i64")),
465+
),
466+
"float64" | "float32" | "float16" => (
466467
PgBuiltInOids::FLOAT8OID.oid(),
467468
value
468469
.as_f64()
469-
.ok_or_else(|| anyhow!("expected {value} to be f64"))?
470-
.into_datum(),
471-
)),
472-
"bool" => row.push((
470+
.map(IntoDatum::into_datum)
471+
.ok_or_else(|| anyhow!("expected column {name} with {value} to be f64")),
472+
),
473+
"bool" => (
473474
PgBuiltInOids::BOOLOID.oid(),
474475
value
475476
.as_bool()
476-
.ok_or_else(|| anyhow!("expected {value} to be bool"))?
477-
.into_datum(),
478-
)),
477+
.map(IntoDatum::into_datum)
478+
.ok_or_else(|| anyhow!("expected column {name} with {value} to be bool")),
479+
),
479480
type_ => {
480481
bail!("unhandled dataset value type while reading dataset: {value:?} {type_:?}")
481482
}
483+
};
484+
match datum {
485+
Ok(datum) => row.push((ty, datum)),
486+
Err(e) => {
487+
warning!("failed to convert dataset value to datum while reading dataset: {e}");
488+
skip = true;
489+
break;
490+
}
482491
}
483492
}
484-
Spi::run_with_args(&insert, Some(row))?
493+
if !skip {
494+
Spi::run_with_args(&insert, Some(row))?
495+
}
485496
}
486497

487498
Ok(num_rows)

0 commit comments

Comments
 (0)