285 lines
15 KiB
Java
285 lines
15 KiB
Java
|
|
package com.labelsys.backend.service;
|
||
|
|
|
||
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
||
|
|
import com.labelsys.backend.common.ResultCode;
|
||
|
|
import com.labelsys.backend.common.exception.BusinessException;
|
||
|
|
import com.labelsys.backend.context.LoginUser;
|
||
|
|
import com.labelsys.backend.dto.common.PageResult;
|
||
|
|
import com.labelsys.backend.dto.request.AnnotationTaskPageQuery;
|
||
|
|
import com.labelsys.backend.dto.request.CreateAnnotationTaskRequest;
|
||
|
|
import com.labelsys.backend.dto.request.UpdateAnnotationTaskRequest;
|
||
|
|
import com.labelsys.backend.dto.response.AnnotationTaskResponse;
|
||
|
|
import com.labelsys.backend.dto.response.TaskModelConfigResponse;
|
||
|
|
import com.labelsys.backend.dto.response.TaskPromptConfigResponse;
|
||
|
|
import com.labelsys.backend.entity.AnnotationTask;
|
||
|
|
import com.labelsys.backend.entity.AnnotationTaskResource;
|
||
|
|
import com.labelsys.backend.entity.SourceResource;
|
||
|
|
import com.labelsys.backend.enums.SourceStatus;
|
||
|
|
import com.labelsys.backend.enums.TaskStatus;
|
||
|
|
import com.labelsys.backend.mapper.AnnotationTaskMapper;
|
||
|
|
import com.labelsys.backend.mapper.AnnotationTaskResourceMapper;
|
||
|
|
import com.labelsys.backend.mapper.SourceResourceMapper;
|
||
|
|
import com.labelsys.backend.util.IdGenerator;
|
||
|
|
import java.util.ArrayList;
|
||
|
|
import java.util.Comparator;
|
||
|
|
import java.util.HashSet;
|
||
|
|
import java.util.List;
|
||
|
|
import java.util.Set;
|
||
|
|
import lombok.RequiredArgsConstructor;
|
||
|
|
import lombok.extern.slf4j.Slf4j;
|
||
|
|
import org.springframework.stereotype.Service;
|
||
|
|
import org.springframework.transaction.annotation.Transactional;
|
||
|
|
import org.springframework.util.StringUtils;
|
||
|
|
|
||
|
|
@Slf4j
|
||
|
|
@Service
|
||
|
|
@RequiredArgsConstructor
|
||
|
|
public class AnnotationTaskService {
|
||
|
|
|
||
|
|
private final AnnotationTaskMapper annotationTaskMapper;
|
||
|
|
private final AnnotationTaskResourceMapper annotationTaskResourceMapper;
|
||
|
|
private final SourceResourceMapper sourceResourceMapper;
|
||
|
|
private final SysConfigService sysConfigService;
|
||
|
|
private final DataPermissionService dataPermissionService;
|
||
|
|
|
||
|
|
@Transactional
|
||
|
|
public AnnotationTaskResponse createTask(LoginUser currentUser, CreateAnnotationTaskRequest request) {
|
||
|
|
List<SourceResource> resources = loadAndValidateResources(currentUser, request.resourceIds());
|
||
|
|
SysConfigService.ResolvedModelConfig extractModel = sysConfigService.resolveModelConfig(currentUser, request.extractModel());
|
||
|
|
SysConfigService.ResolvedModelConfig verifyModel = sysConfigService.resolveModelConfig(currentUser, request.verifyModel());
|
||
|
|
SysConfigService.ResolvedPromptConfig extractPrompt = sysConfigService.resolvePromptConfig(currentUser, request.extractPrompt());
|
||
|
|
SysConfigService.ResolvedPromptConfig verifyPrompt = sysConfigService.resolvePromptConfig(currentUser, request.verifyPrompt());
|
||
|
|
|
||
|
|
AnnotationTask task = AnnotationTask.builder()
|
||
|
|
.id(IdGenerator.nextId())
|
||
|
|
.companyId(currentUser.companyId())
|
||
|
|
.creatorId(currentUser.userId())
|
||
|
|
.creatorRole(currentUser.role())
|
||
|
|
.taskName(request.taskName())
|
||
|
|
.industryType(defaultIndustryType(request.industryType()))
|
||
|
|
.taskType(defaultTaskType(request.taskType()))
|
||
|
|
.extractModelConfigId(extractModel.configId())
|
||
|
|
.extractModelName(extractModel.modelName())
|
||
|
|
.extractModelUrl(extractModel.modelUrl())
|
||
|
|
.extractModelApiKey(extractModel.apiKey())
|
||
|
|
.verifyModelConfigId(verifyModel.configId())
|
||
|
|
.verifyModelName(verifyModel.modelName())
|
||
|
|
.verifyModelUrl(verifyModel.modelUrl())
|
||
|
|
.verifyModelApiKey(verifyModel.apiKey())
|
||
|
|
.extractPromptConfigId(extractPrompt.configId())
|
||
|
|
.extractPrompt(extractPrompt.promptText())
|
||
|
|
.verifyPromptConfigId(verifyPrompt.configId())
|
||
|
|
.verifyPrompt(verifyPrompt.promptText())
|
||
|
|
.taskStatus(TaskStatus.PENDING.name())
|
||
|
|
.isDeleted(false)
|
||
|
|
.build();
|
||
|
|
annotationTaskMapper.insert(task);
|
||
|
|
saveTaskBindings(task.getId(), currentUser.companyId(), resources);
|
||
|
|
log.info("created annotation task, companyId={}, userId={}, taskId={}, resourceCount={}",
|
||
|
|
currentUser.companyId(), currentUser.userId(), task.getId(), resources.size());
|
||
|
|
return buildTaskResponse(task, resourceIds(resources), extractModel, verifyModel, extractPrompt, verifyPrompt);
|
||
|
|
}
|
||
|
|
|
||
|
|
@Transactional
|
||
|
|
public AnnotationTaskResponse updateTask(LoginUser currentUser, Long taskId, UpdateAnnotationTaskRequest request) {
|
||
|
|
AnnotationTask task = annotationTaskMapper.findByIdAndCompanyId(taskId, currentUser.companyId());
|
||
|
|
if (task == null) {
|
||
|
|
throw new BusinessException(ResultCode.NOT_FOUND, "任务不存在");
|
||
|
|
}
|
||
|
|
assertTaskPermission(currentUser, task);
|
||
|
|
|
||
|
|
List<Long> currentResourceIds = normalizeIds(annotationTaskResourceMapper.listResourceIdsByTaskId(taskId));
|
||
|
|
List<Long> targetResourceIds = normalizeIds(request.resourceIds());
|
||
|
|
boolean resourcesChanged = !currentResourceIds.equals(targetResourceIds);
|
||
|
|
if (TaskStatus.RUNNING.name().equals(task.getTaskStatus()) && resourcesChanged) {
|
||
|
|
throw new BusinessException(ResultCode.CONFLICT, "运行中的任务不允许修改资源");
|
||
|
|
}
|
||
|
|
|
||
|
|
List<SourceResource> resources = loadAndValidateResources(currentUser, request.resourceIds());
|
||
|
|
SysConfigService.ResolvedModelConfig extractModel = sysConfigService.resolveModelConfig(currentUser, request.extractModel());
|
||
|
|
SysConfigService.ResolvedModelConfig verifyModel = sysConfigService.resolveModelConfig(currentUser, request.verifyModel());
|
||
|
|
SysConfigService.ResolvedPromptConfig extractPrompt = sysConfigService.resolvePromptConfig(currentUser, request.extractPrompt());
|
||
|
|
SysConfigService.ResolvedPromptConfig verifyPrompt = sysConfigService.resolvePromptConfig(currentUser, request.verifyPrompt());
|
||
|
|
|
||
|
|
task.setIndustryType(defaultIndustryType(request.industryType()));
|
||
|
|
task.setTaskType(defaultTaskType(request.taskType()));
|
||
|
|
task.setExtractModelConfigId(extractModel.configId());
|
||
|
|
task.setExtractModelName(extractModel.modelName());
|
||
|
|
task.setExtractModelUrl(extractModel.modelUrl());
|
||
|
|
task.setExtractModelApiKey(extractModel.apiKey());
|
||
|
|
task.setVerifyModelConfigId(verifyModel.configId());
|
||
|
|
task.setVerifyModelName(verifyModel.modelName());
|
||
|
|
task.setVerifyModelUrl(verifyModel.modelUrl());
|
||
|
|
task.setVerifyModelApiKey(verifyModel.apiKey());
|
||
|
|
task.setExtractPromptConfigId(extractPrompt.configId());
|
||
|
|
task.setExtractPrompt(extractPrompt.promptText());
|
||
|
|
task.setVerifyPromptConfigId(verifyPrompt.configId());
|
||
|
|
task.setVerifyPrompt(verifyPrompt.promptText());
|
||
|
|
annotationTaskMapper.updateById(task);
|
||
|
|
|
||
|
|
if (resourcesChanged) {
|
||
|
|
annotationTaskResourceMapper.deleteByTaskId(taskId);
|
||
|
|
saveTaskBindings(taskId, currentUser.companyId(), resources);
|
||
|
|
}
|
||
|
|
log.info("updated annotation task, companyId={}, userId={}, taskId={}, resourcesChanged={}",
|
||
|
|
currentUser.companyId(), currentUser.userId(), taskId, resourcesChanged);
|
||
|
|
return buildTaskResponse(task, resourceIds(resources), extractModel, verifyModel, extractPrompt, verifyPrompt);
|
||
|
|
}
|
||
|
|
|
||
|
|
public AnnotationTaskResponse getTask(LoginUser currentUser, Long taskId) {
|
||
|
|
AnnotationTask task = annotationTaskMapper.findByIdAndCompanyId(taskId, currentUser.companyId());
|
||
|
|
if (task == null) {
|
||
|
|
throw new BusinessException(ResultCode.NOT_FOUND, "任务不存在");
|
||
|
|
}
|
||
|
|
assertTaskPermission(currentUser, task);
|
||
|
|
return buildTaskResponse(task, normalizeIds(annotationTaskResourceMapper.listResourceIdsByTaskId(taskId)));
|
||
|
|
}
|
||
|
|
|
||
|
|
public PageResult<AnnotationTaskResponse> pageTasks(LoginUser currentUser, AnnotationTaskPageQuery query) {
|
||
|
|
LambdaQueryWrapper<AnnotationTask> wrapper = new LambdaQueryWrapper<AnnotationTask>()
|
||
|
|
.eq(AnnotationTask::getCompanyId, currentUser.companyId())
|
||
|
|
.eq(StringUtils.hasText(query.taskType()), AnnotationTask::getTaskType, query.taskType())
|
||
|
|
.eq(StringUtils.hasText(query.taskStatus()), AnnotationTask::getTaskStatus, query.taskStatus())
|
||
|
|
.eq(query.isDeleted() != null, AnnotationTask::getIsDeleted, query.isDeleted())
|
||
|
|
.like(StringUtils.hasText(query.keyword()), AnnotationTask::getTaskName, query.keyword())
|
||
|
|
.orderByDesc(AnnotationTask::getCreatedAt);
|
||
|
|
List<AnnotationTaskResponse> records = annotationTaskMapper.selectList(wrapper).stream()
|
||
|
|
.filter(task -> dataPermissionService.canAccessCreator(currentUser, task.getCreatorId(), task.getCreatorRole()))
|
||
|
|
.filter(task -> query.resourceId() == null || annotationTaskResourceMapper.listResourceIdsByTaskId(task.getId()).contains(query.resourceId()))
|
||
|
|
.sorted(Comparator.comparing(AnnotationTask::getCreatedAt, Comparator.nullsLast(Comparator.naturalOrder())).reversed())
|
||
|
|
.map(task -> buildTaskResponse(task, normalizeIds(annotationTaskResourceMapper.listResourceIdsByTaskId(task.getId()))))
|
||
|
|
.toList();
|
||
|
|
return paginate(records, query.pageNo(), query.pageSize());
|
||
|
|
}
|
||
|
|
|
||
|
|
@Transactional
|
||
|
|
public void deleteTask(LoginUser currentUser, Long taskId) {
|
||
|
|
AnnotationTask task = annotationTaskMapper.findByIdAndCompanyId(taskId, currentUser.companyId());
|
||
|
|
if (task == null) {
|
||
|
|
throw new BusinessException(ResultCode.NOT_FOUND, "任务不存在");
|
||
|
|
}
|
||
|
|
assertTaskPermission(currentUser, task);
|
||
|
|
if (TaskStatus.RUNNING.name().equals(task.getTaskStatus())) {
|
||
|
|
throw new BusinessException(ResultCode.CONFLICT, "运行中的任务不允许删除");
|
||
|
|
}
|
||
|
|
task.setIsDeleted(true);
|
||
|
|
annotationTaskMapper.updateById(task);
|
||
|
|
log.info("deleted annotation task logically, companyId={}, userId={}, taskId={}",
|
||
|
|
currentUser.companyId(), currentUser.userId(), taskId);
|
||
|
|
}
|
||
|
|
|
||
|
|
private List<SourceResource> loadAndValidateResources(LoginUser currentUser, List<Long> resourceIds) {
|
||
|
|
if (resourceIds == null || resourceIds.isEmpty()) {
|
||
|
|
throw new BusinessException(ResultCode.BAD_REQUEST, "任务资源不能为空");
|
||
|
|
}
|
||
|
|
List<Long> normalizedIds = normalizeIds(resourceIds);
|
||
|
|
List<SourceResource> resources = sourceResourceMapper.selectByCompanyIdAndIds(currentUser.companyId(), normalizedIds);
|
||
|
|
if (resources.size() != normalizedIds.size()) {
|
||
|
|
throw new BusinessException(ResultCode.BAD_REQUEST, "存在无效资源");
|
||
|
|
}
|
||
|
|
for (SourceResource resource : resources) {
|
||
|
|
if (!dataPermissionService.canAccessCreator(currentUser, resource.getCreatorId(), resource.getCreatorRole())) {
|
||
|
|
throw new BusinessException(ResultCode.FORBIDDEN, "无权访问资源");
|
||
|
|
}
|
||
|
|
if (!SourceStatus.READY.name().equals(resource.getSourceStatus())) {
|
||
|
|
throw new BusinessException(ResultCode.BAD_REQUEST, "仅允许选择已就绪资源");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
resources.sort(Comparator.comparing(SourceResource::getId));
|
||
|
|
return resources;
|
||
|
|
}
|
||
|
|
|
||
|
|
private void saveTaskBindings(Long taskId, Long companyId, List<SourceResource> resources) {
|
||
|
|
for (SourceResource resource : resources) {
|
||
|
|
annotationTaskResourceMapper.insert(AnnotationTaskResource.builder()
|
||
|
|
.id(IdGenerator.nextId())
|
||
|
|
.companyId(companyId)
|
||
|
|
.taskId(taskId)
|
||
|
|
.resourceId(resource.getId())
|
||
|
|
.build());
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
private AnnotationTaskResponse buildTaskResponse(AnnotationTask task,
|
||
|
|
List<Long> resourceIds,
|
||
|
|
SysConfigService.ResolvedModelConfig extractModel,
|
||
|
|
SysConfigService.ResolvedModelConfig verifyModel,
|
||
|
|
SysConfigService.ResolvedPromptConfig extractPrompt,
|
||
|
|
SysConfigService.ResolvedPromptConfig verifyPrompt) {
|
||
|
|
return new AnnotationTaskResponse(
|
||
|
|
task.getId(),
|
||
|
|
task.getTaskName(),
|
||
|
|
task.getIndustryType(),
|
||
|
|
task.getTaskType(),
|
||
|
|
task.getTaskStatus(),
|
||
|
|
resourceIds,
|
||
|
|
sysConfigService.toResponse(extractModel),
|
||
|
|
sysConfigService.toResponse(verifyModel),
|
||
|
|
sysConfigService.toResponse(extractPrompt),
|
||
|
|
sysConfigService.toResponse(verifyPrompt),
|
||
|
|
task.getCreatedAt(),
|
||
|
|
task.getUpdatedAt());
|
||
|
|
}
|
||
|
|
|
||
|
|
private AnnotationTaskResponse buildTaskResponse(AnnotationTask task, List<Long> resourceIds) {
|
||
|
|
return new AnnotationTaskResponse(
|
||
|
|
task.getId(),
|
||
|
|
task.getTaskName(),
|
||
|
|
task.getIndustryType(),
|
||
|
|
task.getTaskType(),
|
||
|
|
task.getTaskStatus(),
|
||
|
|
resourceIds,
|
||
|
|
new TaskModelConfigResponse(task.getExtractModelConfigId(), null, task.getExtractModelName(),
|
||
|
|
task.getExtractModelUrl(), maskSecret(task.getExtractModelApiKey())),
|
||
|
|
new TaskModelConfigResponse(task.getVerifyModelConfigId(), null, task.getVerifyModelName(),
|
||
|
|
task.getVerifyModelUrl(), maskSecret(task.getVerifyModelApiKey())),
|
||
|
|
new TaskPromptConfigResponse(task.getExtractPromptConfigId(), null, task.getExtractPrompt()),
|
||
|
|
new TaskPromptConfigResponse(task.getVerifyPromptConfigId(), null, task.getVerifyPrompt()),
|
||
|
|
task.getCreatedAt(),
|
||
|
|
task.getUpdatedAt());
|
||
|
|
}
|
||
|
|
|
||
|
|
private List<Long> resourceIds(List<SourceResource> resources) {
|
||
|
|
return resources.stream().map(SourceResource::getId).sorted().toList();
|
||
|
|
}
|
||
|
|
|
||
|
|
private List<Long> normalizeIds(List<Long> resourceIds) {
|
||
|
|
Set<Long> uniqueIds = new HashSet<>(resourceIds);
|
||
|
|
List<Long> sortedIds = new ArrayList<>(uniqueIds);
|
||
|
|
sortedIds.sort(Long::compareTo);
|
||
|
|
return sortedIds;
|
||
|
|
}
|
||
|
|
|
||
|
|
private void assertTaskPermission(LoginUser currentUser, AnnotationTask task) {
|
||
|
|
if (!dataPermissionService.canAccessCreator(currentUser, task.getCreatorId(), task.getCreatorRole())) {
|
||
|
|
throw new BusinessException(ResultCode.FORBIDDEN, "无权操作任务");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
private String defaultIndustryType(String industryType) {
|
||
|
|
return StringUtils.hasText(industryType) ? industryType : "transport";
|
||
|
|
}
|
||
|
|
|
||
|
|
private String defaultTaskType(String taskType) {
|
||
|
|
return StringUtils.hasText(taskType) ? taskType : "EXTRACT_QA";
|
||
|
|
}
|
||
|
|
|
||
|
|
private String maskSecret(String secret) {
|
||
|
|
if (!StringUtils.hasText(secret)) {
|
||
|
|
return null;
|
||
|
|
}
|
||
|
|
if (secret.length() <= 4) {
|
||
|
|
return "****";
|
||
|
|
}
|
||
|
|
return "****" + secret.substring(secret.length() - 4);
|
||
|
|
}
|
||
|
|
|
||
|
|
private <T> PageResult<T> paginate(List<T> records, Integer pageNo, Integer pageSize) {
|
||
|
|
int actualPageNo = pageNo == null || pageNo < 1 ? 1 : pageNo;
|
||
|
|
int actualPageSize = pageSize == null || pageSize < 1 ? 10 : pageSize;
|
||
|
|
int fromIndex = Math.min((actualPageNo - 1) * actualPageSize, records.size());
|
||
|
|
int toIndex = Math.min(fromIndex + actualPageSize, records.size());
|
||
|
|
return new PageResult<>(records.subList(fromIndex, toIndex), (long) records.size(), actualPageNo, actualPageSize);
|
||
|
|
}
|
||
|
|
}
|