|
|
@@ -1,19 +1,23 @@
|
|
|
-import { DataSource, Repository, ObjectLiteral, DeepPartial } from 'typeorm';
|
|
|
+import { DataSource, Repository, ObjectLiteral, DeepPartial, In } from 'typeorm';
|
|
|
import { z } from '@hono/zod-openapi';
|
|
|
|
|
|
export abstract class GenericCrudService<T extends ObjectLiteral> {
|
|
|
protected repository: Repository<T>;
|
|
|
private userTrackingOptions?: UserTrackingOptions;
|
|
|
|
|
|
+ protected relationFields?: RelationFieldOptions;
|
|
|
+
|
|
|
constructor(
|
|
|
protected dataSource: DataSource,
|
|
|
protected entity: new () => T,
|
|
|
options?: {
|
|
|
userTracking?: UserTrackingOptions;
|
|
|
+ relationFields?: RelationFieldOptions;
|
|
|
}
|
|
|
) {
|
|
|
this.repository = this.dataSource.getRepository(entity);
|
|
|
this.userTrackingOptions = options?.userTracking;
|
|
|
+ this.relationFields = options?.relationFields;
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
@@ -154,13 +158,56 @@ export abstract class GenericCrudService<T extends ObjectLiteral> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 创建实体
|
|
|
+ */
|
|
|
+ /**
|
|
|
+ * 处理关联字段
|
|
|
+ */
|
|
|
+ private async handleRelationFields(data: any, entity: T, isUpdate: boolean = false): Promise<void> {
|
|
|
+ if (!this.relationFields) return;
|
|
|
+
|
|
|
+ for (const [fieldName, config] of Object.entries(this.relationFields)) {
|
|
|
+ if (data[fieldName] !== undefined) {
|
|
|
+ const ids = data[fieldName];
|
|
|
+ const relationRepository = this.dataSource.getRepository(config.targetEntity);
|
|
|
+
|
|
|
+ if (ids && Array.isArray(ids) && ids.length > 0) {
|
|
|
+ const relatedEntities = await relationRepository.findBy({ id: In(ids) });
|
|
|
+ (entity as any)[config.relationName] = relatedEntities;
|
|
|
+ } else {
|
|
|
+ (entity as any)[config.relationName] = [];
|
|
|
+ }
|
|
|
+
|
|
|
+ // 清理原始数据中的关联字段
|
|
|
+ delete data[fieldName];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* 创建实体
|
|
|
*/
|
|
|
async create(data: DeepPartial<T>, userId?: string | number): Promise<T> {
|
|
|
const entityData = { ...data };
|
|
|
this.setUserFields(entityData, userId, true);
|
|
|
+
|
|
|
+ // 分离关联字段数据
|
|
|
+ const relationData: any = {};
|
|
|
+ if (this.relationFields) {
|
|
|
+ for (const fieldName of Object.keys(this.relationFields)) {
|
|
|
+ if (fieldName in entityData) {
|
|
|
+ relationData[fieldName] = (entityData as any)[fieldName];
|
|
|
+ delete (entityData as any)[fieldName];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
const entity = this.repository.create(entityData as DeepPartial<T>);
|
|
|
+
|
|
|
+ // 处理关联字段
|
|
|
+ await this.handleRelationFields(relationData, entity);
|
|
|
+
|
|
|
return this.repository.save(entity);
|
|
|
}
|
|
|
|
|
|
@@ -170,8 +217,29 @@ export abstract class GenericCrudService<T extends ObjectLiteral> {
|
|
|
async update(id: number, data: Partial<T>, userId?: string | number): Promise<T | null> {
|
|
|
const updateData = { ...data };
|
|
|
this.setUserFields(updateData, userId, false);
|
|
|
+
|
|
|
+ // 分离关联字段数据
|
|
|
+ const relationData: any = {};
|
|
|
+ if (this.relationFields) {
|
|
|
+ for (const fieldName of Object.keys(this.relationFields)) {
|
|
|
+ if (fieldName in updateData) {
|
|
|
+ relationData[fieldName] = (updateData as any)[fieldName];
|
|
|
+ delete (updateData as any)[fieldName];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 先更新基础字段
|
|
|
await this.repository.update(id, updateData);
|
|
|
- return this.getById(id);
|
|
|
+
|
|
|
+ // 获取完整实体并处理关联字段
|
|
|
+ const entity = await this.getById(id);
|
|
|
+ if (!entity) return null;
|
|
|
+
|
|
|
+ // 处理关联字段
|
|
|
+ await this.handleRelationFields(relationData, entity, true);
|
|
|
+
|
|
|
+ return this.repository.save(entity);
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
@@ -196,6 +264,14 @@ export interface UserTrackingOptions {
|
|
|
userIdField?: string;
|
|
|
}
|
|
|
|
|
|
+export interface RelationFieldOptions {
|
|
|
+ [fieldName: string]: {
|
|
|
+ relationName: string;
|
|
|
+ targetEntity: new () => any;
|
|
|
+ joinTableName?: string;
|
|
|
+ };
|
|
|
+}
|
|
|
+
|
|
|
export type CrudOptions<
|
|
|
T extends ObjectLiteral,
|
|
|
CreateSchema extends z.ZodSchema = z.ZodSchema,
|
|
|
@@ -212,4 +288,5 @@ export type CrudOptions<
|
|
|
relations?: string[];
|
|
|
middleware?: any[];
|
|
|
userTracking?: UserTrackingOptions;
|
|
|
+ relationFields?: RelationFieldOptions;
|
|
|
};
|