Merged
Changes from 1 commit
Removed Instructor
SilasMarvin committed May 10, 2024
commit 139898d554e8467a6c5772902ece669325d5e480
29 changes: 10 additions & 19 deletions pgml-extension/src/bindings/transformers/transformers.py
@@ -8,7 +8,6 @@
from datetime import datetime

import datasets
from InstructorEmbedding import INSTRUCTOR
import numpy
import orjson
from rouge import Rouge
@@ -502,23 +501,17 @@ def transform(task, args, inputs, stream=False):


def create_embedding(transformer):
instructor = transformer.startswith("hkunlp/instructor")
klass = INSTRUCTOR if instructor else SentenceTransformer
return klass(transformer)
return SentenceTransformer(transformer)


def embed_using(model, transformer, inputs, kwargs):
if isinstance(kwargs, str):
kwargs = orjson.loads(kwargs)

instructor = transformer.startswith("hkunlp/instructor")
if instructor:
texts_with_instructions = []
if instructor and "instruction" in kwargs:
instruction = kwargs.pop("instruction")
for text in inputs:
texts_with_instructions.append([instruction, text])

inputs = texts_with_instructions
kwargs["prompt"] = instruction

return model.encode(inputs, **kwargs)

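Note (editor's addition, not part of the diff): with `InstructorEmbedding` removed, instruction-style prompting goes through `SentenceTransformer.encode()` directly. A minimal sketch of the difference, assuming sentence-transformers >= 2.4 (which added the `prompt` keyword); the model name below is illustrative, not taken from this file:

```python
from sentence_transformers import SentenceTransformer

# Illustrative model; any SentenceTransformer-compatible embedding model works here.
model = SentenceTransformer("intfloat/e5-small-v2")

# Old Instructor path: each input was paired with its instruction,
# e.g. model.encode([["Represent the query:", "hello world"]]).
# New path: the instruction is passed once as a prompt prefix.
embeddings = model.encode(["hello world"], prompt="Represent the query: ")
print(embeddings.shape)
```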
@@ -1029,7 +1022,6 @@ def __init__(
path: str,
hyperparameters: dict,
) -> None:

# initialize class variables
self.project_id = project_id
self.model_id = model_id
@@ -1100,8 +1092,9 @@ def print_number_of_trainable_model_parameters(self, model):
# Calculate and print the number and percentage of trainable parameters
r_log("info", f"Trainable model parameters: {trainable_model_params}")
r_log("info", f"All model parameters: {all_model_params}")
r_log("info",
f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
r_log(
"info",
f"Percentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%",
)

def tokenize_function(self):
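Note (editor's addition, not part of the diff): the reflowed `r_log` call above only changes formatting; the percentage it logs is the usual trainable-parameter ratio. A minimal sketch of how such counts are typically computed in PyTorch; the helper name is illustrative, not from this file:

```python
import torch.nn as nn

def count_trainable_parameters(model: nn.Module):
    # Sum parameter element counts, split by whether gradients are tracked.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total, 100 * trainable / total

# Example: a tiny linear layer has all of its parameters trainable.
trainable, total, pct = count_trainable_parameters(nn.Linear(10, 2))
print(f"Trainable: {trainable}, all: {total}, {pct:.2f}%")
```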
@@ -1396,23 +1389,22 @@ def __init__(
"bias": "none",
"task_type": "CAUSAL_LM",
}
r_log("info",
r_log(
"info",
"LoRA configuration are not set. Using default parameters"
+ json.dumps(self.lora_config_params)
+ json.dumps(self.lora_config_params),
)

self.prompt_template = None
if "prompt_template" in hyperparameters.keys():
self.prompt_template = hyperparameters.pop("prompt_template")

def train(self):

args = TrainingArguments(
output_dir=self.path, logging_dir=self.path, **self.training_args
)

def formatting_prompts_func(example):

system_content = example["system"]
user_content = example["user"]
assistant_content = example["assistant"]
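Note (editor's addition, not part of the diff): the default dictionary above is later unpacked straight into peft's `LoraConfig` (see `peft_config=LoraConfig(**self.lora_config_params)` in the next hunk). A minimal sketch of that pattern; only `bias` and `task_type` are visible in this hunk, so the remaining values are illustrative placeholders:

```python
from peft import LoraConfig

lora_config_params = {
    "r": 8,                    # illustrative placeholder
    "lora_alpha": 16,          # illustrative placeholder
    "lora_dropout": 0.05,      # illustrative placeholder
    "bias": "none",            # visible in the diff
    "task_type": "CAUSAL_LM",  # visible in the diff
}

# Unpacking the dict mirrors LoraConfig(**self.lora_config_params) below.
lora_config = LoraConfig(**lora_config_params)
print(lora_config.task_type)
```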
@@ -1463,7 +1455,7 @@ def formatting_prompts_func(example):
peft_config=LoraConfig(**self.lora_config_params),
callbacks=[PGMLCallback(self.project_id, self.model_id)],
)
r_log("info","Creating Supervised Fine Tuning trainer done. Training ... ")
r_log("info", "Creating Supervised Fine Tuning trainer done. Training ... ")

# Train
self.trainer.train()
@@ -1582,7 +1574,6 @@ def finetune_conversation(
project_id,
model_id,
):

train_dataset = datasets.Dataset.from_dict(
{
"system": system_train,
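Note (editor's addition, not part of the diff): `finetune_conversation` builds its training split with `datasets.Dataset.from_dict`, keyed by conversation role. A minimal sketch with invented example rows; in the real code the lists come from the function's training arguments (only `system_train` is visible in this truncated hunk), and the column names mirror the `example["system"]` / `example["user"]` / `example["assistant"]` accesses shown earlier:

```python
import datasets

# Invented example rows, purely for illustration of the column layout.
train_dataset = datasets.Dataset.from_dict(
    {
        "system": ["You are a helpful assistant."],
        "user": ["What does this commit change?"],
        "assistant": ["It removes the InstructorEmbedding dependency."],
    }
)
print(train_dataset)
```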