Update script/serve_model.py
Browse files- script/serve_model.py +12 -1
script/serve_model.py
CHANGED
|
@@ -301,6 +301,17 @@ if __name__ == "__main__":
|
|
| 301 |
crs_model = CRSModel(crs_model=args.crs_model, **model_args)
|
| 302 |
logger.info(f"Loaded {args.crs_model} model.")
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
# Start CRS Flask server
|
| 305 |
-
crs_server = CRSFlaskServer(
|
|
|
|
|
|
|
| 306 |
crs_server.start(args.host, args.port)
|
|
|
|
| 301 |
crs_model = CRSModel(crs_model=args.crs_model, **model_args)
|
| 302 |
logger.info(f"Loaded {args.crs_model} model.")
|
| 303 |
|
| 304 |
+
# Generation arguments
|
| 305 |
+
response_generation_args = {}
|
| 306 |
+
if args.crs_model == "unicrs":
|
| 307 |
+
response_generation_args = {
|
| 308 |
+
"movie_token": (
|
| 309 |
+
"<movie>" if args.kg_dataset.startswith("redial") else "<mask>"
|
| 310 |
+
),
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
# Start CRS Flask server
|
| 314 |
+
crs_server = CRSFlaskServer(
|
| 315 |
+
crs_model, args.kg_dataset, response_generation_args
|
| 316 |
+
)
|
| 317 |
crs_server.start(args.host, args.port)
|