feat: preload textual model

martabal 2024-09-16 17:53:43 +02:00
parent 4735db8e79
commit 708a53a1eb
17 changed files with 301 additions and 19 deletions

View File

@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Callable, Iterator
 from zipfile import BadZipFile
 
 import orjson
-from fastapi import Depends, FastAPI, File, Form, HTTPException
+from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
 from fastapi.responses import ORJSONResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from PIL.Image import Image
@@ -28,6 +28,7 @@ from .schemas import (
     InferenceEntries,
     InferenceEntry,
     InferenceResponse,
+    LoadModelEntry,
     MessageResponse,
     ModelFormat,
     ModelIdentity,
@@ -124,6 +125,24 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
         raise HTTPException(422, "Invalid request format.")
 
 
+def get_entry(entries: str = Form()) -> LoadModelEntry:
+    try:
+        request: PipelineRequest = orjson.loads(entries)
+        for task, types in request.items():
+            for type, entry in types.items():
+                parsed: LoadModelEntry = {
+                    "name": entry["modelName"],
+                    "task": task,
+                    "type": type,
+                    "options": entry.get("options", {}),
+                    "ttl": entry["ttl"] if "ttl" in entry else settings.model_ttl,
+                }
+                # a TypedDict cannot validate in __init__, so enforce a positive ttl here
+                if parsed["ttl"] <= 0:
+                    raise ValueError("ttl must be a positive integer")
+                return parsed
+    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError, ValueError) as e:
+        log.error(f"Invalid request format: {e}")
+        raise HTTPException(422, "Invalid request format.")
+
+
 app = FastAPI(lifespan=lifespan)
@@ -137,6 +156,13 @@ def ping() -> str:
     return "pong"
 
 
+@app.post("/load", response_model=TextResponse)
+async def load_model(entry: LoadModelEntry = Depends(get_entry)) -> Response:
+    model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=entry["ttl"])
+    await load(model)
+    return Response(status_code=200)
+
+
 @app.post("/predict", dependencies=[Depends(update_state)])
 async def predict(
     entries: InferenceEntries = Depends(get_entries),
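
For context, a minimal sketch of what a preload request to the new /load endpoint could look like from a TypeScript client. The machine-learning service URL and the model name are illustrative assumptions; the payload shape follows get_entry above and the request built in MachineLearningRepository.loadTextModel further down.

// Sketch only: the URL and model name are assumed values, not part of this commit.
import { ModelTask, ModelType } from 'src/interfaces/machine-learning.interface';

const entries = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai', ttl: 300 } } };

const formData = new FormData();
formData.append('entries', JSON.stringify(entries));

const res = await fetch(new URL('/load', 'http://localhost:3003'), { method: 'POST', body: formData });
console.log(res.status); // 200 once the textual model has been loaded into the cache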

View File

@@ -109,6 +109,17 @@ class InferenceEntry(TypedDict):
     options: dict[str, Any]
 
 
+class LoadModelEntry(InferenceEntry):
+    # Seconds of inactivity before the preloaded model is unloaded; must be
+    # positive. TypedDict instances are plain dicts and never run a custom
+    # __init__, so the value is validated in get_entry, where it is parsed.
+    ttl: int
+
+
 InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]

View File

@@ -337,6 +337,7 @@ Class | Method | HTTP request | Description
 - [LibraryStatsResponseDto](doc//LibraryStatsResponseDto.md)
 - [LicenseKeyDto](doc//LicenseKeyDto.md)
 - [LicenseResponseDto](doc//LicenseResponseDto.md)
+- [LoadTextualModelOnConnection](doc//LoadTextualModelOnConnection.md)
 - [LogLevel](doc//LogLevel.md)
 - [LoginCredentialDto](doc//LoginCredentialDto.md)
 - [LoginResponseDto](doc//LoginResponseDto.md)

View File

@@ -151,6 +151,7 @@ part 'model/library_response_dto.dart';
 part 'model/library_stats_response_dto.dart';
 part 'model/license_key_dto.dart';
 part 'model/license_response_dto.dart';
+part 'model/load_textual_model_on_connection.dart';
 part 'model/log_level.dart';
 part 'model/login_credential_dto.dart';
 part 'model/login_response_dto.dart';

View File

@@ -357,6 +357,8 @@ class ApiClient {
         return LicenseKeyDto.fromJson(value);
       case 'LicenseResponseDto':
         return LicenseResponseDto.fromJson(value);
+      case 'LoadTextualModelOnConnection':
+        return LoadTextualModelOnConnection.fromJson(value);
       case 'LogLevel':
         return LogLevelTypeTransformer().decode(value);
       case 'LoginCredentialDto':

View File

@@ -14,30 +14,36 @@ class CLIPConfig {
   /// Returns a new [CLIPConfig] instance.
   CLIPConfig({
     required this.enabled,
+    required this.loadTextualModelOnConnection,
     required this.modelName,
   });
 
   bool enabled;
 
+  LoadTextualModelOnConnection loadTextualModelOnConnection;
+
   String modelName;
 
   @override
   bool operator ==(Object other) => identical(this, other) || other is CLIPConfig &&
     other.enabled == enabled &&
+    other.loadTextualModelOnConnection == loadTextualModelOnConnection &&
     other.modelName == modelName;
 
   @override
   int get hashCode =>
     // ignore: unnecessary_parenthesis
     (enabled.hashCode) +
+    (loadTextualModelOnConnection.hashCode) +
     (modelName.hashCode);
 
   @override
-  String toString() => 'CLIPConfig[enabled=$enabled, modelName=$modelName]';
+  String toString() => 'CLIPConfig[enabled=$enabled, loadTextualModelOnConnection=$loadTextualModelOnConnection, modelName=$modelName]';
 
   Map<String, dynamic> toJson() {
     final json = <String, dynamic>{};
     json[r'enabled'] = this.enabled;
+    json[r'loadTextualModelOnConnection'] = this.loadTextualModelOnConnection;
     json[r'modelName'] = this.modelName;
     return json;
   }
@@ -51,6 +57,7 @@ class CLIPConfig {
     return CLIPConfig(
       enabled: mapValueOfType<bool>(json, r'enabled')!,
+      loadTextualModelOnConnection: LoadTextualModelOnConnection.fromJson(json[r'loadTextualModelOnConnection'])!,
       modelName: mapValueOfType<String>(json, r'modelName')!,
     );
   }
@@ -100,6 +107,7 @@ class CLIPConfig {
   /// The list of required keys that must be present in a JSON.
   static const requiredKeys = <String>{
     'enabled',
+    'loadTextualModelOnConnection',
     'modelName',
   };
 }

View File

@@ -0,0 +1,107 @@
+//
+// AUTO-GENERATED FILE, DO NOT MODIFY!
+//
+// @dart=2.18
+
+// ignore_for_file: unused_element, unused_import
+// ignore_for_file: always_put_required_named_parameters_first
+// ignore_for_file: constant_identifier_names
+// ignore_for_file: lines_longer_than_80_chars
+
+part of openapi.api;
+
+class LoadTextualModelOnConnection {
+  /// Returns a new [LoadTextualModelOnConnection] instance.
+  LoadTextualModelOnConnection({
+    required this.enabled,
+    required this.ttl,
+  });
+
+  bool enabled;
+
+  /// Minimum value: 0
+  num ttl;
+
+  @override
+  bool operator ==(Object other) => identical(this, other) || other is LoadTextualModelOnConnection &&
+    other.enabled == enabled &&
+    other.ttl == ttl;
+
+  @override
+  int get hashCode =>
+    // ignore: unnecessary_parenthesis
+    (enabled.hashCode) +
+    (ttl.hashCode);
+
+  @override
+  String toString() => 'LoadTextualModelOnConnection[enabled=$enabled, ttl=$ttl]';
+
+  Map<String, dynamic> toJson() {
+    final json = <String, dynamic>{};
+    json[r'enabled'] = this.enabled;
+    json[r'ttl'] = this.ttl;
+    return json;
+  }
+
+  /// Returns a new [LoadTextualModelOnConnection] instance and imports its values from
+  /// [value] if it's a [Map], null otherwise.
+  // ignore: prefer_constructors_over_static_methods
+  static LoadTextualModelOnConnection? fromJson(dynamic value) {
+    if (value is Map) {
+      final json = value.cast<String, dynamic>();
+      return LoadTextualModelOnConnection(
+        enabled: mapValueOfType<bool>(json, r'enabled')!,
+        ttl: num.parse('${json[r'ttl']}'),
+      );
+    }
+    return null;
+  }
+
+  static List<LoadTextualModelOnConnection> listFromJson(dynamic json, {bool growable = false,}) {
+    final result = <LoadTextualModelOnConnection>[];
+    if (json is List && json.isNotEmpty) {
+      for (final row in json) {
+        final value = LoadTextualModelOnConnection.fromJson(row);
+        if (value != null) {
+          result.add(value);
+        }
+      }
+    }
+    return result.toList(growable: growable);
+  }
+
+  static Map<String, LoadTextualModelOnConnection> mapFromJson(dynamic json) {
+    final map = <String, LoadTextualModelOnConnection>{};
+    if (json is Map && json.isNotEmpty) {
+      json = json.cast<String, dynamic>(); // ignore: parameter_assignments
+      for (final entry in json.entries) {
+        final value = LoadTextualModelOnConnection.fromJson(entry.value);
+        if (value != null) {
+          map[entry.key] = value;
+        }
+      }
+    }
+    return map;
+  }
+
+  // maps a json object with a list of LoadTextualModelOnConnection-objects as value to a dart map
+  static Map<String, List<LoadTextualModelOnConnection>> mapListFromJson(dynamic json, {bool growable = false,}) {
+    final map = <String, List<LoadTextualModelOnConnection>>{};
+    if (json is Map && json.isNotEmpty) {
+      // ignore: parameter_assignments
+      json = json.cast<String, dynamic>();
+      for (final entry in json.entries) {
+        map[entry.key] = LoadTextualModelOnConnection.listFromJson(entry.value, growable: growable,);
+      }
+    }
+    return map;
+  }
+
+  /// The list of required keys that must be present in a JSON.
+  static const requiredKeys = <String>{
+    'enabled',
+    'ttl',
+  };
+}

View File

@@ -8603,12 +8603,16 @@
         "enabled": {
           "type": "boolean"
         },
+        "loadTextualModelOnConnection": {
+          "$ref": "#/components/schemas/LoadTextualModelOnConnection"
+        },
         "modelName": {
           "type": "string"
         }
       },
       "required": [
         "enabled",
+        "loadTextualModelOnConnection",
         "modelName"
       ],
       "type": "object"
@@ -9433,6 +9437,23 @@
       ],
       "type": "object"
     },
+    "LoadTextualModelOnConnection": {
+      "properties": {
+        "enabled": {
+          "type": "boolean"
+        },
+        "ttl": {
+          "format": "int64",
+          "minimum": 0,
+          "type": "number"
+        }
+      },
+      "required": [
+        "enabled",
+        "ttl"
+      ],
+      "type": "object"
+    },
     "LogLevel": {
       "enum": [
         "verbose",

View File

@@ -1100,8 +1100,13 @@ export type SystemConfigLoggingDto = {
   enabled: boolean;
   level: LogLevel;
 };
+export type LoadTextualModelOnConnection = {
+  enabled: boolean;
+  ttl: number;
+};
 export type ClipConfig = {
   enabled: boolean;
+  loadTextualModelOnConnection: LoadTextualModelOnConnection;
   modelName: string;
 };
 export type DuplicateDetectionConfig = {

View File

@@ -120,6 +120,10 @@ export interface SystemConfig {
     clip: {
       enabled: boolean;
       modelName: string;
+      loadTextualModelOnConnection: {
+        enabled: boolean;
+        ttl: number;
+      };
     };
     duplicateDetection: {
       enabled: boolean;
@@ -270,6 +274,10 @@ export const defaults = Object.freeze<SystemConfig>({
     clip: {
       enabled: true,
       modelName: 'ViT-B-32__openai',
+      loadTextualModelOnConnection: {
+        enabled: false,
+        ttl: 300,
+      },
     },
     duplicateDetection: {
       enabled: true,

View File

@@ -1,6 +1,6 @@
 import { ApiProperty } from '@nestjs/swagger';
 import { Type } from 'class-transformer';
-import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
+import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
 import { ValidateBoolean } from 'src/validation';
 
 export class TaskConfig {
@@ -14,7 +14,20 @@ export class ModelConfig extends TaskConfig {
   modelName!: string;
 }
 
-export class CLIPConfig extends ModelConfig {}
+export class LoadTextualModelOnConnection extends TaskConfig {
+  @IsNumber()
+  @Min(0)
+  @Type(() => Number)
+  @ApiProperty({ type: 'number', format: 'int64' })
+  ttl!: number;
+}
+
+export class CLIPConfig extends ModelConfig {
+  @Type(() => LoadTextualModelOnConnection)
+  @ValidateNested()
+  @IsObject()
+  loadTextualModelOnConnection!: LoadTextualModelOnConnection;
+}
 
 export class DuplicateDetectionConfig extends TaskConfig {
   @IsNumber()

View File

@@ -24,13 +24,17 @@ export type ModelPayload = { imagePath: string } | { text: string };
 
 type ModelOptions = { modelName: string };
 
+export interface LoadModelOptions extends ModelOptions {
+  ttl: number;
+}
+
 export type FaceDetectionOptions = ModelOptions & { minScore: number };
 
 type VisualResponse = { imageHeight: number; imageWidth: number };
 export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
 export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
 
-export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
+export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
 export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
 
 export type FacialRecognitionRequest = {
@@ -54,4 +58,5 @@ export interface IMachineLearningRepository {
   encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
   encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
   detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
+  loadTextModel(url: string, config: ModelOptions): Promise<void>;
 }
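
Since ClipTextualRequest now accepts either options shape, the same request type covers both a search and a preload. A small illustration using the types from this interface file; the model name and ttl are taken from the server defaults above and are otherwise arbitrary:

// Both shapes are valid ClipTextualRequest values after this change.
const searchRequest: ClipTextualRequest = {
  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai' } },
};
const preloadRequest: ClipTextualRequest = {
  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai', ttl: 300 } },
};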

View File

@@ -9,6 +9,7 @@ import {
   WebSocketServer,
 } from '@nestjs/websockets';
 import { Server, Socket } from 'socket.io';
+import { SystemConfigCore } from 'src/cores/system-config.core';
 import {
   ArgsOf,
   ClientEventMap,
@@ -19,6 +20,8 @@ import {
   ServerEventMap,
 } from 'src/interfaces/event.interface';
 import { ILoggerRepository } from 'src/interfaces/logger.interface';
+import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
+import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
 import { AuthService } from 'src/services/auth.service';
 import { Instrumentation } from 'src/utils/instrumentation';
@@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: EmitHandler<T>[] }>;
 @Injectable()
 export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
   private emitHandlers: EmitHandlers = {};
+  private configCore: SystemConfigCore;
 
   @WebSocketServer()
   private server?: Server;
@@ -41,8 +45,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
     private moduleRef: ModuleRef,
     private eventEmitter: EventEmitter2,
     @Inject(ILoggerRepository) private logger: ILoggerRepository,
+    @Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository,
+    @Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository,
   ) {
     this.logger.setContext(EventRepository.name);
+    this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger);
   }
 
   afterInit(server: Server) {
@@ -68,6 +75,16 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
       queryParams: {},
       metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
     });
+    if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
+      const { machineLearning } = await this.configCore.getConfig({ withCache: true });
+      if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
+        // Fire-and-forget so the handshake is not blocked while the model loads;
+        // the call is not awaited, so failures are logged via the promise handler.
+        this.machineLearningRepository
+          .loadTextModel(machineLearning.url, machineLearning.clip)
+          .catch((error: Error) => this.logger.warn(error));
+      }
+    }
     await client.join(auth.user.id);
 
     if (auth.session) {
       await client.join(auth.session.id);

View File

@@ -20,13 +20,9 @@ const errorPrefix = 'Machine learning request';
 
 @Injectable()
 export class MachineLearningRepository implements IMachineLearningRepository {
   private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
-    const formData = await this.getFormData(payload, config);
+    const formData = await this.getFormData(config, payload);
-    const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
-      (error: Error | any) => {
-        throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
-      },
-    );
+    const res = await this.fetchData(url, '/predict', formData);
 
     if (res.status >= 400) {
       throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
@@ -34,6 +30,25 @@ export class MachineLearningRepository implements IMachineLearningRepository {
     return res.json();
   }
 
+  private async fetchData(url: string, path: string, formData?: FormData): Promise<Response> {
+    const res = await fetch(new URL(path, url), { method: 'POST', body: formData }).catch((error: Error | any) => {
+      throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
+    });
+    return res;
+  }
+
+  async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig) {
+    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } };
+    const formData = await this.getFormData(request);
+    const res = await this.fetchData(url, '/load', formData);
+    // Let failures propagate; the caller decides whether a preload error is fatal.
+    if (res.status >= 400) {
+      throw new Error(`${errorPrefix} Loading textual model failed with status ${res.status}: ${res.statusText}`);
+    }
+  }
+
   async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
     const request = {
       [ModelTask.FACIAL_RECOGNITION]: {
@@ -61,16 +76,17 @@ export class MachineLearningRepository implements IMachineLearningRepository {
     return response[ModelTask.SEARCH];
   }
 
-  private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
+  private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise<FormData> {
     const formData = new FormData();
     formData.append('entries', JSON.stringify(config));
 
-    if ('imagePath' in payload) {
-      formData.append('image', new Blob([await readFile(payload.imagePath)]));
-    } else if ('text' in payload) {
-      formData.append('text', payload.text);
-    } else {
-      throw new Error('Invalid input');
+    if (payload) {
+      if ('imagePath' in payload) {
+        formData.append('image', new Blob([await readFile(payload.imagePath)]));
+      } else if ('text' in payload) {
+        formData.append('text', payload.text);
+      } else {
+        throw new Error('Invalid input');
+      }
     }
 
     return formData;
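
A minimal usage sketch of the new method, assuming an injected IMachineLearningRepository and a CLIPConfig loaded from the system config; the helper name is hypothetical, and the import paths assume this repository's src layout:

import { CLIPConfig } from 'src/dtos/model-config.dto';
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';

// Hypothetical helper: preload the textual model when the feature is enabled.
export const preloadTextualModel = async (
  repository: IMachineLearningRepository,
  url: string,
  clip: CLIPConfig,
): Promise<void> => {
  if (clip.loadTextualModelOnConnection.enabled) {
    await repository.loadTextModel(url, clip); // resolves once /load has returned 200
  }
};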

View File

@@ -75,6 +75,38 @@
             </FormatMessage>
           </p>
         </SettingInputField>
+
+        <SettingAccordion
+          key="Preload clip model"
+          title={$t('admin.machine_learning_preload_model')}
+          subtitle={$t('admin.machine_learning_preload_model_setting_description')}
+        >
+          <div class="ml-4 mt-4 flex flex-col gap-4">
+            <SettingSwitch
+              title={$t('admin.machine_learning_preload_model_enabled')}
+              subtitle={$t('admin.machine_learning_preload_model_enabled_description')}
+              bind:checked={config.machineLearning.clip.loadTextualModelOnConnection.enabled}
+              disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.clip.enabled}
+            />
+
+            <hr />
+
+            <SettingInputField
+              inputType={SettingInputFieldType.NUMBER}
+              label={$t('admin.machine_learning_preload_model_ttl')}
+              bind:value={config.machineLearning.clip.loadTextualModelOnConnection.ttl}
+              step="1"
+              min={0}
+              desc={$t('admin.machine_learning_preload_model_ttl_description')}
+              disabled={disabled ||
+                !config.machineLearning.enabled ||
+                !config.machineLearning.clip.enabled ||
+                !config.machineLearning.clip.loadTextualModelOnConnection.enabled}
+              isEdited={config.machineLearning.clip.loadTextualModelOnConnection.ttl !==
+                savedConfig.machineLearning.clip.loadTextualModelOnConnection.ttl}
+            />
+          </div>
+        </SettingAccordion>
       </div>
     </SettingAccordion>

View File

@@ -114,6 +114,12 @@
   "machine_learning_min_detection_score_description": "Minimum confidence score for a face to be detected from 0-1. Lower values will detect more faces but may result in false positives.",
   "machine_learning_min_recognized_faces": "Minimum recognized faces",
   "machine_learning_min_recognized_faces_description": "The minimum number of recognized faces for a person to be created. Increasing this makes Facial Recognition more precise at the cost of increasing the chance that a face is not assigned to a person.",
+  "machine_learning_preload_model": "Preload model",
+  "machine_learning_preload_model_enabled": "Enable model preloading",
+  "machine_learning_preload_model_enabled_description": "Preload the textual model when a client connects instead of at the first search",
+  "machine_learning_preload_model_setting_description": "Preload the textual model when a client connects",
+  "machine_learning_preload_model_ttl": "Inactivity time before a model is unloaded",
+  "machine_learning_preload_model_ttl_description": "Time in seconds of inactivity after which the preloaded model is unloaded",
   "machine_learning_settings": "Machine Learning Settings",
   "machine_learning_settings_description": "Manage machine learning features and settings",
   "machine_learning_smart_search": "Smart Search",

View File

@@ -35,6 +35,9 @@ const websocket: Socket<Events> = io({
   reconnection: true,
   forceNew: true,
   autoConnect: false,
+  query: {
+    // socket.io stringifies query values, so the server compares against 'false'
+    background: false,
+  },
 });
export const websocketStore = {