11import asyncio
2- from pgml import Database
2+ from pgml import Collection , Model , Splitter , Pipeline
33import logging
44from rich .logging import RichHandler
55from rich .progress import track
@@ -77,26 +77,28 @@ def handler(signum, frame):
7777
7878
7979# The code is using the `argparse` module to parse command line arguments.
80- collection_name = args .collection_name
80+ collection = Collection ( args .collection_name )
8181stage = args .stage
8282chat_interface = args .chat_interface
8383
8484# The above code is retrieving environment variables and assigning their values to various variables.
8585database_url = os .environ .get ("DATABASE_URL" )
86- db = Database (database_url )
87- splitter = os .environ .get ("SPLITTER" , "recursive_character" )
86+ splitter_name = os .environ .get ("SPLITTER" , "recursive_character" )
8887splitter_params = os .environ .get (
8988 "SPLITTER_PARAMS" , {"chunk_size" : 1500 , "chunk_overlap" : 40 }
9089)
91- model = os .environ .get ("MODEL" , "intfloat/e5-small" )
90+ splitter = Splitter (splitter_name , splitter_params )
91+ model_name = os .environ .get ("MODEL" , "intfloat/e5-small" )
9292model_params = ast .literal_eval (os .environ .get ("MODEL_PARAMS" , {}))
93+ model = Model (model_name , "pgml" , model_params )
94+ pipeline = Pipeline (args .collection_name + "_pipeline" , model , splitter )
9395query_params = ast .literal_eval (os .environ .get ("QUERY_PARAMS" , {}))
9496system_prompt = os .environ .get ("SYSTEM_PROMPT" )
9597base_prompt = os .environ .get ("BASE_PROMPT" )
9698openai_api_key = os .environ .get ("OPENAI_API_KEY" )
9799
98100
99- async def upsert_documents (db : Database , collection_name : str , folder : str ) -> int :
101+ async def upsert_documents (folder : str ) -> int :
100102 log .info ("Scanning " + folder + " for markdown files" )
101103 md_files = []
102104 # root_dir needs a trailing slash (i.e. /root/dir/)
@@ -107,100 +109,14 @@ async def upsert_documents(db: Database, collection_name: str, folder: str) -> i
107109 documents = []
108110 for md_file in track (md_files , description = "Extracting text from markdown" ):
109111 with open (md_file , "r" ) as f :
110-             documents .append ({"text" : f .read (), "filename" : md_file })
112+             documents .append ({"text" : f .read (), "id" : md_file })
111113
112114 log .info ("Upserting documents into database" )
113- collection = await db .create_or_get_collection (collection_name )
114115 await collection .upsert_documents (documents )
115116
116117 return len (md_files )
117118
118119
119- async def generate_chunks (
120- db : Database ,
121- collection_name : str ,
122- splitter : str = "recursive_character" ,
123- splitter_params : dict = {"chunk_size" : 1500 , "chunk_overlap" : 40 },
124- ) -> int :
125- """
126- The function `generate_chunks` generates chunks for a given collection in a database and returns the
127- count of chunks created.
128-
129- :param db: The `db` parameter is an instance of a database connection or client. It is used to
130- interact with the database and perform operations such as creating collections, executing queries,
131- and fetching results
132- :type db: Database
133- :param collection_name: The `collection_name` parameter is a string that represents the name of the
134- collection in the database. It is used to create or get the collection and perform operations on it
135- :type collection_name: str
136- :return: The function `generate_chunks` returns an integer, which represents the count of chunks
137- generated in the specified collection.
138- """
139- log .info ("Generating chunks" )
140- collection = await db .create_or_get_collection (collection_name )
141- splitter_id = await collection .register_text_splitter (splitter , splitter_params )
142- query_string = """SELECT count(*) from {collection_name}.chunks""" .format (
143- collection_name = collection_name
144- )
145- results = await db .query (query_string ).fetch_all ()
146- start_chunks = results [0 ]["count" ]
147- log .info ("Starting chunk count: " + str (start_chunks ))
148- await collection .generate_chunks (splitter_id )
149- results = await db .query (query_string ).fetch_all ()
150- log .info ("Ending chunk count: " + str (results [0 ]["count" ]))
151- return results [0 ]["count" ] - start_chunks
152-
153-
154- async def generate_embeddings (
155- db : Database ,
156- collection_name : str ,
157- splitter : str = "recursive_character" ,
158- splitter_params : dict = {"chunk_size" : 1500 , "chunk_overlap" : 40 },
159- model : str = "intfloat/e5-small" ,
160- model_params : dict = {},
161- ) -> int :
162- """
163- The `generate_embeddings` function generates embeddings for text data using a specified model and
164- splitter.
165-
166- :param db: The `db` parameter is an instance of a database object. It is used to interact with the
167- database and perform operations such as creating or getting a collection, registering a text
168- splitter, registering a model, and generating embeddings
169- :type db: Database
170- :param collection_name: The `collection_name` parameter is a string that represents the name of the
171- collection in the database where the embeddings will be generated
172- :type collection_name: str
173- :param splitter: The `splitter` parameter is used to specify the text splitting method to be used
174- during the embedding generation process. In this case, the value is set to "recursive_character",
175- which suggests that the text will be split into chunks based on recursive character splitting,
176- defaults to recursive_character
177- :type splitter: str (optional)
178- :param splitter_params: The `splitter_params` parameter is a dictionary that contains the parameters
179- for the text splitter. In this case, the `splitter_params` dictionary has two keys:
180- :type splitter_params: dict
181- :param model: The `model` parameter is the name or identifier of the language model that will be
182- used to generate the embeddings. In this case, the model is specified as "intfloat/e5-small",
183- defaults to intfloat/e5-small
184- :type model: str (optional)
185- :param model_params: The `model_params` parameter is a dictionary that allows you to specify
186- additional parameters for the model. These parameters can be used to customize the behavior of the
187- model during the embedding generation process. The specific parameters that can be included in the
188- `model_params` dictionary will depend on the specific model you are
189- :type model_params: dict
190- :return: an integer value of 0.
191- """
192- log .info ("Generating embeddings" )
193- collection = await db .create_or_get_collection (collection_name )
194- splitter_id = await collection .register_text_splitter (splitter , splitter_params )
195- model_id = await collection .register_model ("embedding" , model , model_params )
196- log .info ("Splitter ID: " + str (splitter_id ))
197- start = time ()
198- await collection .generate_embeddings (model_id , splitter_id )
199- log .info ("Embeddings generated in %0.3f seconds" % (time () - start ))
200-
201- return 0
202-
203-
204120async def generate_response (
205121 messages , openai_api_key , temperature = 0.7 , max_tokens = 256 , top_p = 0.9
206122):
@@ -217,44 +133,20 @@ async def generate_response(
217133 return response ["choices" ][0 ]["message" ]["content" ]
218134
219135
220- async def ingest_documents (
221- db : Database ,
222- collection_name : str ,
223- folder : str ,
224- splitter : str ,
225- splitter_params : dict ,
226- model : str ,
227- model_params : dict ,
228- ):
229- total_docs = await upsert_documents (db , collection_name , folder = folder )
230- total_chunks = await generate_chunks (
231- db , collection_name , splitter = splitter , splitter_params = splitter_params
232- )
233- log .info (
234- "Total documents: " + str (total_docs ) + " Total chunks: " + str (total_chunks )
235- )
236-
237- await generate_embeddings (
238- db ,
239- collection_name ,
240- splitter = splitter ,
241- splitter_params = splitter_params ,
242- model = model ,
243- model_params = model_params ,
244- )
136+ async def ingest_documents (folder : str ):
137+ # Add the pipeline to the collection, does nothing if we have already added it
138+ await collection .add_pipeline (pipeline )
139+ # This will upsert, chunk, and embed the contents in the folder
140+ total_docs = await upsert_documents (folder )
141+ log .info ("Total documents: " + str (total_docs ))
245142
246143
247144async def get_prompt (user_input : str = "" ):
248- collection = await db .create_or_get_collection (collection_name )
249- model_id = await collection .register_model ("embedding" , model , model_params )
250- splitter_id = await collection .register_text_splitter (splitter , splitter_params )
251- log .info ("Model id: " + str (model_id ) + " Splitter id: " + str (splitter_id ))
252- vector_results = await collection .vector_search (
253- user_input ,
254- model_id = model_id ,
255- splitter_id = splitter_id ,
256- top_k = 2 ,
257- query_params = query_params ,
145+ vector_results = (
146+ await collection .query ()
147+ .vector_recall (user_input , pipeline , query_params )
148+ .limit (2 )
149+ .fetch_all ()
258150 )
259151 log .info (vector_results )
260152 context = ""
@@ -322,10 +214,12 @@ async def message_hello(message, say):
322214intents .message_content = True
323215client = discord .Client (intents = intents )
324216
217+
325218@client .event
326219async def on_ready ():
327220 print (f"We have logged in as { client .user } " )
328221
222+
329223@client .event
330224async def on_message (message ):
331225     bot_mention = f"<@{client.user.id}>"
@@ -351,15 +245,7 @@ async def run():
351245
352246 if stage == "ingest" :
353247 root_dir = args .root_dir
354- await ingest_documents (
355- db ,
356- collection_name ,
357- root_dir ,
358- splitter ,
359- splitter_params ,
360- model ,
361- model_params ,
362- )
248+ await ingest_documents (root_dir )
363249
364250 elif stage == "chat" :
365251 if chat_interface == "cli" :
@@ -369,7 +255,12 @@ async def run():
369255
370256
371257def main ():
372- if stage == "chat" and chat_interface == "discord" and os .environ .get ("DISCORD_BOT_TOKEN" ):
258+ if (
259+ stage == "chat"
260+ and chat_interface == "discord"
261+ and os .environ .get ("DISCORD_BOT_TOKEN" )
262+ ):
373263 client .run (os .environ ["DISCORD_BOT_TOKEN" ])
374264 else :
375265 asyncio .run (run ())
266+ main ()
0 commit comments