@@ -21,6 +21,7 @@ CREATE TABLE pgml.model_versions(
2121 name VARCHAR ,
2222 location VARCHAR NULL ,
2323 data_source TEXT ,
24+ y_column VARCHAR ,
2425 started_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP ,
2526 ended_at TIMESTAMP WITHOUT TIME ZONE NULL ,
2627 mean_squared_error DOUBLE PRECISION ,
@@ -50,29 +51,34 @@ CREATE OR REPLACE FUNCTION pgml_train(table_name TEXT, y TEXT)
5051RETURNS TEXT
5152AS $$
5253 from pgml .train import train
54+ from pgml .sql import models_directory
55+ import os
5356
5457 data_source = f" SELECT * FROM {table_name}"
5558
5659 # Start training.
5760 start = plpy .execute (f" " "
5861 INSERT INTO pgml.model_versions
59- (name, data_source)
60- VALUES ('{table_name}', '{data_source}')
62+ (name, data_source, y_column)
63+ VALUES
64+ ('{table_name}', '{data_source}', '{y}')
6165 RETURNING *" " " , 1 )
6266
6367 id_ = start[0 ][" id" ]
6468 name = f" {table_name}_{id_}"
6569
70+ destination = models_directory(plpy)
71+
6672 # Train!
67- location, msq, r2 = train(plpy .cursor (data_source), y_column= y, name= name)
73+ location, msq, r2 = train(plpy .cursor (data_source), y_column= y, name= name, destination = destination )
6874
6975 plpy .execute (f" " "
7076 UPDATE pgml.model_versions
7177 SET location = '{location}',
7278 successful = true,
73- ended_at = NOW(),
7479 mean_squared_error = '{msq}',
75- r2_score = '{r2}'
80+ r2_score = '{r2}',
81+ ended_at = clock_timestamp()
7682 WHERE id = {id_}" " " )
7783
7884 return name
@@ -86,14 +92,15 @@ DROP FUNCTION pgml_score(model_name TEXT, VARIADIC features DOUBLE PRECISION[]);
8692CREATE OR REPLACE FUNCTION pgml_score (model_name TEXT , VARIADIC features DOUBLE PRECISION [])
8793RETURNS DOUBLE PRECISION
8894AS $$
95+ from pgml .sql import models_directory
96+ from pgml .score import load
8997 import pickle
9098
9199 if model_name in SD:
92100 model = SD[model_name]
93101 else:
94- with open(f" /app/models/{model_name}" , " rb" ) as f:
95- model = pickle .load (f)
96- SD[model_name] = model
102+ SD[model_name] = load(model_name, models_directory(plpy))
103+ model = SD[model_name]
97104
98105 scores = model .predict ([features,])
99106 return scores[0 ]
0 commit comments