🌐 AI搜索 & 代理 主页
Skip to content

Commit 74bfa5e

Browse files
committed
move models dir inside pg
1 parent 4867e50 commit 74bfa5e

File tree

5 files changed

+55
-17
lines changed

5 files changed

+55
-17
lines changed

README.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,6 @@ sudo python3 setup.py install
7171
cd ../
7272
```
7373

74-
Create a directory for PostgresML to store its models (this is currently hardcoded):
75-
76-
```
77-
mkdir -p /app/models
78-
chown postgres:postgres /app/models
79-
```
80-
8174
Run the test:
8275

8376
```

pgml/pgml/score.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Score"""
2+
3+
import os
4+
import pickle
5+
6+
from pgml.exceptions import PgMLException
7+
8+
9+
def load(name, source):
    """Load a pickled model from file.

    Arguments:
        name: File name of the saved model inside ``source``.
        source: Directory the model files are stored in.

    Returns:
        The object unpickled from ``os.path.join(source, name)``.

    Raises:
        PgMLException: If there is no file at the resulting path.
    """
    path = os.path.join(source, name)

    # Open first and translate the failure, instead of checking
    # os.path.exists() beforehand: the file can disappear between the check
    # and the open, which would leak a raw FileNotFoundError to the caller.
    try:
        f = open(path, "rb")
    except (FileNotFoundError, NotADirectoryError) as err:
        raise PgMLException(f"Model file `{path}` does not exist.") from err

    # NOTE(review): pickle executes arbitrary code while loading. This is only
    # safe as long as the models directory is writable solely by trusted users.
    with f:
        return pickle.load(f)

pgml/pgml/sql.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tools to run SQL.
22
"""
3+
import os
34

45

56
def all_rows(cursor):
@@ -11,3 +12,21 @@ def all_rows(cursor):
1112

1213
for row in rows:
1314
yield row
15+
16+
17+
def models_directory(plpy):
    """Get the directory where we store our models.

    The directory is ``pgml_models`` inside the server's ``data_directory``
    and is created on first use.

    Arguments:
        plpy: Object exposing ``execute(query, limit)``; it is used to read
            the ``data_directory`` setting from ``pg_settings``.

    Returns:
        Path of the models directory.
    """
    data_directory = plpy.execute(
        """
        SELECT setting FROM pg_settings WHERE name = 'data_directory'
        """,
        1,
    )[0]["setting"]

    models_dir = os.path.join(data_directory, "pgml_models")

    # TODO: Ideally this happens during extension installation.
    # Create unconditionally and ignore "already exists": checking with
    # os.path.exists() first is racy when two callers hit this at once, and
    # the loser of that race would fail with FileExistsError.
    try:
        os.mkdir(models_dir, 0o770)
    except FileExistsError:
        pass

    return models_dir

pgml/pgml/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
from sklearn.metrics import mean_squared_error, r2_score
99

1010
import pickle
11+
import os
1112

1213
from pgml.sql import all_rows
1314
from pgml.exceptions import PgMLException
1415
from pgml.validate import check_type
1516

1617

17-
def train(cursor, y_column, name, save=True):
18+
def train(cursor, y_column, name, save=True, destination="/tmp/pgml_models"):
1819
"""Train the model on data on some rows.
1920
2021
Arguments:
@@ -65,7 +66,8 @@ def train(cursor, y_column, name, save=True):
6566
msq = mean_squared_error(y_test, y_pred)
6667
r2 = r2_score(y_test, y_pred)
6768

68-
path = f"/app/models/{name}"
69+
path = os.path.join(destination, name)
70+
6971
if save:
7072
with open(path, "wb") as f:
7173
pickle.dump(lr, f)

sql/install.sql

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ CREATE TABLE pgml.model_versions(
2121
name VARCHAR,
2222
location VARCHAR NULL,
2323
data_source TEXT,
24+
y_column VARCHAR,
2425
started_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP,
2526
ended_at TIMESTAMP WITHOUT TIME ZONE NULL,
2627
mean_squared_error DOUBLE PRECISION,
@@ -50,29 +51,34 @@ CREATE OR REPLACE FUNCTION pgml_train(table_name TEXT, y TEXT)
5051
RETURNS TEXT
5152
AS $$
5253
from pgml.train import train
54+
from pgml.sql import models_directory
55+
import os
5356

5457
data_source = f"SELECT * FROM {table_name}"
5558

5659
# Start training.
5760
start = plpy.execute(f"""
5861
INSERT INTO pgml.model_versions
59-
(name, data_source)
60-
VALUES ('{table_name}', '{data_source}')
62+
(name, data_source, y_column)
63+
VALUES
64+
('{table_name}', '{data_source}', '{y}')
6165
RETURNING *""", 1)
6266

6367
id_ = start[0]["id"]
6468
name = f"{table_name}_{id_}"
6569

70+
destination = models_directory(plpy)
71+
6672
# Train!
67-
location, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name)
73+
location, msq, r2 = train(plpy.cursor(data_source), y_column=y, name=name, destination=destination)
6874

6975
plpy.execute(f"""
7076
UPDATE pgml.model_versions
7177
SET location = '{location}',
7278
successful = true,
73-
ended_at = NOW(),
7479
mean_squared_error = '{msq}',
75-
r2_score = '{r2}'
80+
r2_score = '{r2}',
81+
ended_at = clock_timestamp()
7682
WHERE id = {id_}""")
7783

7884
return name
@@ -86,14 +92,15 @@ DROP FUNCTION pgml_score(model_name TEXT, VARIADIC features DOUBLE PRECISION[]);
8692
CREATE OR REPLACE FUNCTION pgml_score(model_name TEXT, VARIADIC features DOUBLE PRECISION[])
8793
RETURNS DOUBLE PRECISION
8894
AS $$
95+
from pgml.sql import models_directory
96+
from pgml.score import load
8997
import pickle
9098

9199
if model_name in SD:
92100
model = SD[model_name]
93101
else:
94-
with open(f"/app/models/{model_name}", "rb") as f:
95-
model = pickle.load(f)
96-
SD[model_name] = model
102+
SD[model_name] = load(model_name, models_directory(plpy))
103+
model = SD[model_name]
97104

98105
scores = model.predict([features,])
99106
return scores[0]

0 commit comments

Comments
 (0)