@@ -6,19 +6,21 @@ use rust_bridge::{alias, alias_methods};
 use sea_query::Alias;
 use sea_query::{Expr, NullOrdering, Order, PostgresQueryBuilder, Query};
 use sea_query_binder::SqlxBinder;
-use serde_json::json;
-use sqlx::Executor;
+use serde_json::{json, Value};
 use sqlx::PgConnection;
+use sqlx::{Executor, Pool, Postgres};
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::Path;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
+use tokio::task::JoinSet;
 use tracing::{instrument, warn};
 use walkdir::WalkDir;

 use crate::debug_sqlx_query;
 use crate::filter_builder::FilterBuilder;
+use crate::pipeline::FieldAction;
 use crate::search_query_builder::build_search_query;
 use crate::vector_search_query_builder::build_vector_search_query;
 use crate::{
@@ -496,28 +498,80 @@ impl Collection {
         // -> Insert the document
         // -> For each pipeline, check if we need to resync the document and if so sync it
         // -> Commit the transaction
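+        // -> Process batches concurrently, up to `parallel_batches` at a time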
+        let mut args = args.unwrap_or_default();
+        let args = args.as_object_mut().context("args must be a JSON object")?;
+
         self.verify_in_database(false).await?;
         let mut pipelines = self.get_pipelines().await?;

         let pool = get_or_initialize_pool(&self.database_url).await?;

-        let mut parsed_schemas = vec![];
         let project_info = &self.database_data.as_ref().unwrap().project_info;
+        let mut parsed_schemas = vec![];
         for pipeline in &mut pipelines {
             let parsed_schema = pipeline
                 .get_parsed_schema(project_info, &pool)
                 .await
                 .expect("Error getting parsed schema for pipeline");
             parsed_schemas.push(parsed_schema);
         }
-        let mut pipelines: Vec<(Pipeline, _)> = pipelines.into_iter().zip(parsed_schemas).collect();
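+        // Pair each pipeline with its parsed schema so both can be cloned into the batch tasks.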
+        let pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)> =
+            pipelines.into_iter().zip(parsed_schemas).collect();

-        let args = args.unwrap_or_default();
-        let args = args.as_object().context("args must be a JSON object")?;
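+        // "batch_size" (default 100) is removed from args so it is not forwarded to the batch tasks.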
+        let batch_size = args
+            .remove("batch_size")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(100))?;
+
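+        // "parallel_batches" (default 1, i.e. sequential) caps how many batches run at once.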
+        let parallel_batches = args
+            .get("parallel_batches")
+            .map(|x| x.try_to_u64())
+            .unwrap_or(Ok(1))? as usize;

         let progress_bar = utils::default_progress_bar(documents.len() as u64);
         progress_bar.println("Upserting Documents...");

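+        // Bounded parallelism: spawn batch tasks into a JoinSet, and once the set is
+        // full, wait for one task to finish before spawning the next batch.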
+        let mut set = JoinSet::new();
+        for batch in documents.chunks(batch_size as usize) {
+            if set.len() >= parallel_batches {
+                if let Some(res) = set.join_next().await {
+                    res??;
+                    progress_bar.inc(batch_size);
+                }
+            }
+            let local_self = self.clone();
+            let local_batch = batch.to_owned();
+            let local_args = args.clone();
+            let local_pipelines = pipelines.clone();
+            let local_pool = pool.clone();
+            set.spawn(async move {
+                local_self
+                    ._upsert_documents(local_batch, local_args, local_pipelines, local_pool)
+                    .await
+            });
+        }
+
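+        // Wait for any batches still in flight.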
+        while let Some(res) = set.join_next().await {
+            res??;
+            progress_bar.inc(batch_size);
+        }
+
+        progress_bar.println("Done Upserting Documents\n");
+        progress_bar.finish();
+
+        Ok(())
+    }
+
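+    /// Upserts one batch of documents in a single transaction, then syncs each
+    /// pipeline for documents that are new or whose tracked fields changed.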
+    async fn _upsert_documents(
+        self,
+        batch: Vec<Json>,
+        args: serde_json::Map<String, Value>,
+        mut pipelines: Vec<(Pipeline, HashMap<String, FieldAction>)>,
+        pool: Pool<Postgres>,
+    ) -> anyhow::Result<()> {
+        let project_info = &self.database_data.as_ref().unwrap().project_info;
+
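+        // Choose the document upsert query based on args["merge"].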
         let query = if args
             .get("merge")
             .map(|v| v.as_bool().unwrap_or(false))
@@ -539,111 +593,99 @@ impl Collection {
             )
         };

-        let batch_size = args
-            .get("batch_size")
-            .map(TryToNumeric::try_to_u64)
-            .unwrap_or(Ok(100))?;
-
-        for batch in documents.chunks(batch_size as usize) {
-            let mut transaction = pool.begin().await?;
-
-            let mut query_values = String::new();
-            let mut binding_parameter_counter = 1;
-            for _ in 0..batch.len() {
-                query_values = format!(
-                    "{query_values}, (${}, ${}, ${})",
-                    binding_parameter_counter,
-                    binding_parameter_counter + 1,
-                    binding_parameter_counter + 2
-                );
-                binding_parameter_counter += 3;
-            }
+        let mut transaction = pool.begin().await?;

-            let query = query.replace(
-                "{values_parameters}",
-                &query_values.chars().skip(1).collect::<String>(),
-            );
-            let query = query.replace(
-                "{binding_parameter}",
-                &format!("${binding_parameter_counter}"),
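+        // Build a "($1, $2, $3), ($4, $5, $6), ..." placeholder list: three bindings
+        // per document (source_uuid, document, versions).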
+        let mut query_values = String::new();
+        let mut binding_parameter_counter = 1;
+        for _ in 0..batch.len() {
+            query_values = format!(
+                "{query_values}, (${}, ${}, ${})",
+                binding_parameter_counter,
+                binding_parameter_counter + 1,
+                binding_parameter_counter + 2
             );
+            binding_parameter_counter += 3;
+        }

-            let mut query = sqlx::query_as(&query);
-
-            let mut source_uuids = vec![];
-            for document in batch {
-                let id = document
-                    .get("id")
-                    .context("`id` must be a key in document")?
-                    .to_string();
-                let md5_digest = md5::compute(id.as_bytes());
-                let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
-                source_uuids.push(source_uuid);
-
-                let start = SystemTime::now();
-                let timestamp = start
-                    .duration_since(UNIX_EPOCH)
-                    .expect("Time went backwards")
-                    .as_millis();
-
-                let versions: HashMap<String, serde_json::Value> = document
-                    .as_object()
-                    .context("document must be an object")?
-                    .iter()
-                    .try_fold(HashMap::new(), |mut acc, (key, value)| {
-                        let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
-                        let md5_digest = format!("{md5_digest:x}");
-                        acc.insert(
-                            key.to_owned(),
-                            serde_json::json!({
-                                "last_updated": timestamp,
-                                "md5": md5_digest
-                            }),
-                        );
-                        anyhow::Ok(acc)
-                    })?;
-                let versions = serde_json::to_value(versions)?;
-
-                query = query.bind(source_uuid).bind(document).bind(versions);
-            }
+        let query = query.replace(
+            "{values_parameters}",
+            &query_values.chars().skip(1).collect::<String>(),
+        );
+        let query = query.replace(
+            "{binding_parameter}",
+            &format!("${binding_parameter_counter}"),
+        );

-            let results: Vec<(i64, Option<Json>)> = query
-                .bind(source_uuids)
-                .fetch_all(&mut *transaction)
-                .await?;
+        let mut query = sqlx::query_as(&query);
+
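+        // Each document gets a stable UUID derived from the MD5 of its "id", so
+        // re-upserting the same id hits the same row.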
+        let mut source_uuids = vec![];
+        for document in &batch {
+            let id = document
+                .get("id")
+                .context("`id` must be a key in document")?
+                .to_string();
+            let md5_digest = md5::compute(id.as_bytes());
+            let source_uuid = uuid::Uuid::from_slice(&md5_digest.0)?;
+            source_uuids.push(source_uuid);
+
+            let start = SystemTime::now();
+            let timestamp = start
+                .duration_since(UNIX_EPOCH)
+                .expect("Time went backwards")
+                .as_millis();
+
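+            // Store a per-field MD5 and timestamp so a later upsert can detect
+            // which fields actually changed.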
+            let versions: HashMap<String, serde_json::Value> = document
+                .as_object()
+                .context("document must be an object")?
+                .iter()
+                .try_fold(HashMap::new(), |mut acc, (key, value)| {
+                    let md5_digest = md5::compute(serde_json::to_string(value)?.as_bytes());
+                    let md5_digest = format!("{md5_digest:x}");
+                    acc.insert(
+                        key.to_owned(),
+                        serde_json::json!({
+                            "last_updated": timestamp,
+                            "md5": md5_digest
+                        }),
+                    );
+                    anyhow::Ok(acc)
+                })?;
+            let versions = serde_json::to_value(versions)?;

-            let dp: Vec<(i64, Json, Option<Json>)> = results
-                .into_iter()
-                .zip(batch)
-                .map(|((id, previous_document), document)| {
-                    (id, document.to_owned(), previous_document)
+            query = query.bind(source_uuid).bind(document).bind(versions);
+        }
+
+        let results: Vec<(i64, Option<Json>)> = query
+            .bind(source_uuids)
+            .fetch_all(&mut *transaction)
+            .await?;
+
+        let dp: Vec<(i64, Json, Option<Json>)> = results
+            .into_iter()
+            .zip(batch)
+            .map(|((id, previous_document), document)| (id, document.to_owned(), previous_document))
+            .collect();
+
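+        // Re-run each pipeline only on documents that are new or whose tracked
+        // fields changed since the previous version.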
+        for (pipeline, parsed_schema) in &mut pipelines {
+            let ids_to_run_on: Vec<i64> = dp
+                .iter()
+                .filter(|(_, document, previous_document)| match previous_document {
+                    Some(previous_document) => parsed_schema
+                        .iter()
+                        .any(|(key, _)| document[key] != previous_document[key]),
+                    None => true,
                 })
+                .map(|(document_id, _, _)| *document_id)
                 .collect();
-
-            for (pipeline, parsed_schema) in &mut pipelines {
-                let ids_to_run_on: Vec<i64> = dp
-                    .iter()
-                    .filter(|(_, document, previous_document)| match previous_document {
-                        Some(previous_document) => parsed_schema
-                            .iter()
-                            .any(|(key, _)| document[key] != previous_document[key]),
-                        None => true,
-                    })
-                    .map(|(document_id, _, _)| *document_id)
-                    .collect();
-                if !ids_to_run_on.is_empty() {
-                    pipeline
-                        .sync_documents(ids_to_run_on, project_info, &mut transaction)
-                        .await
-                        .expect("Failed to execute pipeline");
-                }
+            if !ids_to_run_on.is_empty() {
+                pipeline
+                    .sync_documents(ids_to_run_on, project_info, &mut transaction)
+                    .await
+                    .expect("Failed to execute pipeline");
            }
-
-            transaction.commit().await?;
-            progress_bar.inc(batch_size);
         }
-        progress_bar.println("Done Upserting Documents\n");
-        progress_bar.finish();
+
+        transaction.commit().await?;
         Ok(())
     }
