🌐 AI搜索 & 代理 主页
Skip to content

Commit c24419e

Browse files
committed
rollbacks
1 parent 5f69604 commit c24419e

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

pgml/pgml/model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -440,13 +440,23 @@ def train(
440440
)
441441

442442
snapshot = Snapshot.create(relation_name, y_column_name, test_size, test_sampling)
443-
best_model = None
444-
best_error = None
443+
deployed = Model.find_deployed(project.id)
444+
445+
# Let's assume that the deployed model is better for now.
446+
best_model = deployed
447+
best_error = best_model.mean_squared_error if best_model else None
445448

446449
for algorithm_name in algorithms:
447450
model = Model.create(project, snapshot, algorithm_name)
448451
model.fit(snapshot)
452+
453+
# Find the better model and deploy that.
449454
if best_error is None or model.mean_squared_error < best_error:
450455
best_error = model.mean_squared_error
451456
best_model = model
452-
best_model.deploy()
457+
458+
if deployed and deployed.id == best_model.id:
459+
return "rolled back"
460+
else:
461+
best_model.deploy()
462+
return "deployed"

sql/install.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ RETURNS TABLE(project_name TEXT, objective TEXT, status TEXT)
109109
AS $$
110110
from pgml.model import train
111111

112-
train(project_name, objective, relation_name, y_column_name)
112+
status = train(project_name, objective, relation_name, y_column_name)
113113

114-
return [(project_name, objective, "deployed")]
114+
return [(project_name, objective, status)]
115115
$$ LANGUAGE plpython3u;
116116

117117
---

0 commit comments

Comments
 (0)