🌐 AI搜索 & 代理 主页
Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
OpenSourceAI support
  • Loading branch information
santiadavani committed Dec 5, 2023
commit df3121d77a03a29294c5909df556f4943cd3d973
66 changes: 51 additions & 15 deletions pgml-apps/pgml-chat/pgml_chat/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from pgml import Collection, Model, Splitter, Pipeline, migrate, init_logger, Builtins
from pgml import Collection, Model, Splitter, Pipeline, migrate, init_logger, Builtins, OpenSourceAI
import logging
from rich.logging import RichHandler
from rich.progress import track
Expand All @@ -9,7 +9,7 @@
import glob
import argparse
from time import time
import openai
from openai import OpenAI
import signal
from uuid import uuid4
import pendulum
Expand Down Expand Up @@ -110,6 +110,12 @@ def handler(signum, frame):
help="Persona of the bot",
)

parser.add_argument(
    "--chat_completion_model",
    dest="chat_completion_model",
    # argparse's ``type`` must be a callable used to convert the raw CLI
    # string; the original passed the *string* "str", which argparse
    # rejects with "'str' object is not callable" at parse time.
    type=str,
    default="gpt-3.5-turbo-16k",
    help="Chat completion model to use (an OpenAI model name, or an "
    "open-source model served via PostgresML OpenSourceAI)",
)

args = parser.parse_args()

Expand Down Expand Up @@ -160,6 +166,8 @@ def handler(signum, frame):
chat_history_collection_name + "_pipeline", model, splitter
)

chat_completion_model = args.chat_completion_model

query_params_instruction = (
"Represent the %s question for retrieving supporting documents: " % (bot_topic)
)
Expand Down Expand Up @@ -203,6 +211,7 @@ def handler(signum, frame):

Helpful Answer:"""


openai_api_key = os.environ.get("OPENAI_API_KEY")

system_prompt_document = [
Expand All @@ -215,6 +224,20 @@ def handler(signum, frame):
}
]

def model_type(chat_completion_model: str) -> str:
    """Classify a chat model name as ``"openai"`` or ``"opensourceai"``.

    Lists the models available to the configured OpenAI account and
    returns ``"openai"`` when *chat_completion_model* appears in that
    list. Any failure (missing/invalid API key, network error) falls
    back to ``"opensourceai"`` so the open-source path is used.

    :param chat_completion_model: model identifier, e.g. "gpt-4".
    :return: ``"openai"`` or ``"opensourceai"``.
    """
    detected_type = "opensourceai"
    try:
        client = OpenAI(api_key=openai_api_key)
        # The openai v1 client returns Model objects: use attribute
        # access. The original ``model["id"]`` raised TypeError, was
        # swallowed by the broad except, and silently classified every
        # model as "opensourceai".
        for model in client.models.list().data:
            if model.id == chat_completion_model:
                detected_type = "openai"
                break
    except Exception as e:
        # Best-effort probe: log and keep the open-source fallback.
        log.error("Could not query OpenAI model list: %s", e)

    return detected_type

async def upsert_documents(folder: str) -> int:
log.info("Scanning " + folder + " for markdown files")
md_files = []
Expand Down Expand Up @@ -335,19 +358,32 @@ async def generate_chat_response(
async def generate_response(
    messages, openai_api_key, temperature=0.7, max_tokens=256, top_p=0.9
):
    """Generate a chat completion from OpenAI or PostgresML OpenSourceAI.

    Routes to the OpenAI API when the module-level
    ``chat_completion_model`` is a known OpenAI model, otherwise to the
    PostgresML OpenSourceAI client (module-level ``database_url`` must
    be set — defined elsewhere in main.py).

    :param messages: chat history as a list of role/content dicts.
    :param openai_api_key: API key used for the OpenAI path.
    :param temperature: sampling temperature.
    :param max_tokens: maximum tokens to generate.
    :param top_p: nucleus-sampling cutoff (OpenAI path only).
    :return: the assistant's reply text.
    """
    # Do NOT assign to a variable named ``model_type`` here: rebinding
    # the name would make it local to this function and the call itself
    # would raise UnboundLocalError before the assignment runs.
    backend = model_type(chat_completion_model)
    if backend == "openai":
        client = OpenAI(api_key=openai_api_key)
        log.debug("Generating response from OpenAI API: " + str(messages))
        response = client.chat.completions.create(
            model=chat_completion_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            frequency_penalty=0,
            presence_penalty=0,
        )
        # v1 responses are objects, not dicts.
        output = response.choices[0].message.content
    else:
        client = OpenSourceAI(database_url=database_url)
        log.debug("Generating response from OpenSourceAI API: " + str(messages))
        response = client.chat_completions_create(
            model=chat_completion_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # OpenSourceAI returns a plain dict in the OpenAI wire format.
        output = response["choices"][0]["message"]["content"]

    return output


async def ingest_documents(folder: str):
Expand Down
Loading