diff --git a/.gitignore b/.gitignore index 3d40740..4b05b93 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ CLAUDE.md *.iws .agents/ .history/ +.trae/ logs/ # ========================================== diff --git a/src/main/java/com/labelsys/backend/controller/AnnotationAgentConfigController.java b/src/main/java/com/labelsys/backend/controller/AnnotationAgentConfigController.java new file mode 100644 index 0000000..b917358 --- /dev/null +++ b/src/main/java/com/labelsys/backend/controller/AnnotationAgentConfigController.java @@ -0,0 +1,46 @@ +package com.labelsys.backend.controller; + +import com.labelsys.backend.annotation.RequirePosition; +import com.labelsys.backend.common.Result; +import com.labelsys.backend.context.UserContext; +import com.labelsys.backend.dto.request.SaveAgentConfigRequest; +import com.labelsys.backend.dto.response.AgentConfigListResponse; +import com.labelsys.backend.entity.AnnotationAgentConfig; +import com.labelsys.backend.enums.UserPosition; +import com.labelsys.backend.service.AnnotationAgentConfigService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.validation.Valid; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@Tag(name = "Agent配置管理") +@RestController +@RequestMapping("/api/agent-configs") +@RequiredArgsConstructor +public class AnnotationAgentConfigController { + + private final AnnotationAgentConfigService annotationAgentConfigService; + + @Operation(summary = "保存Agent配置") + @PostMapping + @RequirePosition(UserPosition.ADMIN) // 仅限ADMIN及以上岗位访问 + public Result save(@Valid @RequestBody SaveAgentConfigRequest request) { + return Result.success( + annotationAgentConfigService.saveAgentConfigs(UserContext.requireUser(), request)); + } + + @Operation(summary = "获取公司Agent对应配置列表:模型配置和提示词配置") + @GetMapping + @RequirePosition(UserPosition.ADMIN) // 仅限ADMIN及以上岗位访问 + public Result list() { + var user = UserContext.requireUser(); + return Result.success( + annotationAgentConfigService.getAgentConfigsForCompany(user.companyId())); + } + +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/controller/SysConfigController.java b/src/main/java/com/labelsys/backend/controller/SysConfigController.java index 50b71f2..c9c9124 100644 --- a/src/main/java/com/labelsys/backend/controller/SysConfigController.java +++ b/src/main/java/com/labelsys/backend/controller/SysConfigController.java @@ -1,5 +1,20 @@ package com.labelsys.backend.controller; +import com.labelsys.backend.annotation.RequirePosition; +import com.labelsys.backend.common.Result; +import com.labelsys.backend.context.UserContext; +import com.labelsys.backend.dto.common.PageResult; +import com.labelsys.backend.dto.request.SaveSysConfigRequest; +import com.labelsys.backend.dto.request.SysConfigPageQuery; +import com.labelsys.backend.dto.request.UpdateSysConfigRequest; +import com.labelsys.backend.dto.response.SysConfigResponse; +import com.labelsys.backend.enums.UserPosition; +import com.labelsys.backend.service.SysConfigService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.validation.Valid; +import lombok.RequiredArgsConstructor; import org.springdoc.core.annotations.ParameterObject; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -9,21 +24,6 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; -import com.labelsys.backend.common.Result; -import com.labelsys.backend.context.UserContext; -import com.labelsys.backend.dto.common.PageResult; -import com.labelsys.backend.dto.request.SaveSysConfigRequest; -import com.labelsys.backend.dto.request.SysConfigPageQuery; -import com.labelsys.backend.dto.request.UpdateSysConfigRequest; -import com.labelsys.backend.dto.response.SysConfigResponse; -import com.labelsys.backend.service.SysConfigService; - -import io.swagger.v3.oas.annotations.Operation; -import io.swagger.v3.oas.annotations.Parameter; -import io.swagger.v3.oas.annotations.tags.Tag; -import jakarta.validation.Valid; -import lombok.RequiredArgsConstructor; - @Tag(name = "系统配置管理") @RestController @RequestMapping("/api/sys-configs") @@ -33,33 +33,35 @@ public class SysConfigController { private final SysConfigService sysConfigService; @Operation(summary = "创建系统配置") - // @RequirePosition(UserPosition.ADMIN) @PostMapping + @RequirePosition(UserPosition.ADMIN) // 仅限ADMIN及以上岗位访问 public Result create(@Valid @RequestBody SaveSysConfigRequest request) { return Result - .success(sysConfigService.toResponse(sysConfigService.saveConfig(UserContext.requireUser(), request))); + .success(sysConfigService.toResponse(sysConfigService.saveConfig(UserContext.requireUser(), request))); } @Operation(summary = "更新系统配置") - // @RequirePosition(UserPosition.ADMIN) @PutMapping("/{id}") + @RequirePosition(UserPosition.ADMIN) // 仅限ADMIN及以上岗位访问 public Result update( - @Parameter(description = "配置ID", example = "191000000000000501") @PathVariable Long id, - @Valid @RequestBody UpdateSysConfigRequest request) { + @Parameter(description = "配置ID", example = "191000000000000501") @PathVariable Long id, + @Valid @RequestBody UpdateSysConfigRequest request) { return Result.success( - sysConfigService.toResponse(sysConfigService.updateConfig(UserContext.requireUser(), id, request))); + sysConfigService.toResponse(sysConfigService.updateConfig(UserContext.requireUser(), id, request))); } @Operation(summary = "分页查询系统配置") @GetMapping + @RequirePosition(UserPosition.ADMIN) // 仅限ADMIN及以上岗位访问 public Result> page(@ParameterObject SysConfigPageQuery query) { return Result.success(sysConfigService.pageConfigs(UserContext.requireUser(), query)); } - @Operation(summary = "查询系统配置详情") + @Operation(summary = "查询配置详情") @GetMapping("/{id}") - public Result - detail(@Parameter(description = "配置ID", example = "191000000000000501") @PathVariable Long id) { - return Result.success(sysConfigService.getConfig(UserContext.requireUser(), id)); + @RequirePosition(UserPosition.ADMIN) // 仅限ADMIN及以上岗位访问 + public Result detail( + @Parameter(description = "配置ID", example = "191000000000000501") @PathVariable Long id) { + return Result.success(sysConfigService.getConfigDetail(UserContext.requireUser(), id)); } } diff --git a/src/main/java/com/labelsys/backend/dto/LlmConfigModel.java b/src/main/java/com/labelsys/backend/dto/LlmConfigModel.java new file mode 100644 index 0000000..7c8a053 --- /dev/null +++ b/src/main/java/com/labelsys/backend/dto/LlmConfigModel.java @@ -0,0 +1,20 @@ +package com.labelsys.backend.dto; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +@Schema(description = "大模型配置模型") +@Data +public class LlmConfigModel { + @Schema(description = "多模态类型枚举") + private String llmType; + + @Schema(description = "模型名称") + private String modelName; + + @Schema(description = "模型调用地址") + private String modelUrl; + + @Schema(description = "API密钥,加密存储") + private String apiKey; +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/dto/request/SaveAgentConfigRequest.java b/src/main/java/com/labelsys/backend/dto/request/SaveAgentConfigRequest.java new file mode 100644 index 0000000..81faaad --- /dev/null +++ b/src/main/java/com/labelsys/backend/dto/request/SaveAgentConfigRequest.java @@ -0,0 +1,28 @@ +package com.labelsys.backend.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; +import lombok.Data; + +import java.util.Map; + +@Schema(description = "保存Agent配置请求: 请求体中agentConfigs为map结构,key为AgentType枚举值,值为用户选择的模型配置ID和提示词配置ID") +@Data +public class SaveAgentConfigRequest { + + @Schema(description = "Agent配置映射,key为AgentType枚举类型code,value为配置信息。支持的Agent类型:RedTeamAgent(红黑对抗)、AnalyzerAgent(分析)、IndustryClassifierAgent(行业识别)、HallucinationDetectorAgent(幻觉检测)、ReviewerAgent(审查)、regenerator(再次生成)") + @NotNull(message = "Agent配置不能为空") + private Map agentConfigs; + + @Schema(description = "Agent配置项") + @Data + public static class AgentConfigItem { + + @Schema(description = "模型配置ID", example = "191000000000000501") + @NotNull(message = "模型配置ID不能为空") + private Long modelConfigId; + + @Schema(description = "提示词配置ID", example = "191000000000000502") + private Long promptConfigId; + } +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/dto/request/SaveSysConfigRequest.java b/src/main/java/com/labelsys/backend/dto/request/SaveSysConfigRequest.java index 90fb7d2..9ff2967 100644 --- a/src/main/java/com/labelsys/backend/dto/request/SaveSysConfigRequest.java +++ b/src/main/java/com/labelsys/backend/dto/request/SaveSysConfigRequest.java @@ -5,9 +5,10 @@ import jakarta.validation.constraints.NotBlank; @Schema(description = "保存系统配置请求") public record SaveSysConfigRequest( - @Schema(description = "配置类型", example = "MODEL") @NotBlank(message = "配置类型不能为空") String configType, - @Schema(description = "配置名称", example = "qwen-plus-extract") @NotBlank(message = "配置名称不能为空") String configName, - @Schema(description = "配置值", - example = "{\"modelName\":\"qwen-plus\",\"modelUrl\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\",\"apiKey\":\"sk-demo1234\"}") @NotBlank( - message = "配置值不能为空") String configValue, - @Schema(description = "配置状态", example = "ENABLED") String status) {} + @Schema(description = "配置类型: MODEL-大模型配置、PROMPT-提示词配置、SYSTEM-其他配置项", example = "MODEL") @NotBlank(message = "配置类型不能为空") String configType, + @Schema(description = "配置名称", example = "qwen-plus-extract") @NotBlank(message = "配置名称不能为空") String configName, + @Schema(description = "配置值, 当configType为model时候,configValue为大模型配置的json字符串,configType为其他值时,configValue为普通字符串", + example = "{\"modelName\":\"qwen-plus\",\"modelUrl\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\",\"apiKey\":\"sk-demo1234\"}") @NotBlank( + message = "配置值不能为空") String configValue, + @Schema(description = "配置状态", example = "ENABLED") String status) { +} diff --git a/src/main/java/com/labelsys/backend/dto/request/UpdateSysConfigRequest.java b/src/main/java/com/labelsys/backend/dto/request/UpdateSysConfigRequest.java index 7b68f89..10c85c8 100644 --- a/src/main/java/com/labelsys/backend/dto/request/UpdateSysConfigRequest.java +++ b/src/main/java/com/labelsys/backend/dto/request/UpdateSysConfigRequest.java @@ -6,9 +6,9 @@ import jakarta.validation.constraints.NotBlank; @Schema(description = "更新系统配置请求") public record UpdateSysConfigRequest( - @Schema(description = "配置类型", example = "MODEL") String configType, + @Schema(description = "配置类型: MODEL-大模型配置、PROMPT-提示词配置、SYSTEM-其他配置项\", example = \"MODEL\"", example = "MODEL") String configType, @Schema(description = "配置名称", example = "qwen-plus-extract") String configName, - @Schema(description = "配置值", - example = "{\"modelName\":\"qwen-plus\",\"modelUrl\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\",\"apiKey\":\"sk-demo1234\"}") @NotBlank( + @Schema(description = "配置值, 当configType为model时候,configValue为大模型配置的json字符串,configType为其他值时,configValue为普通字符串", + example = "{\"modelName\":\"qwen-plus\",\"modelUrl\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\",\"apiKey\":\"sk-demo1234\"}") @NotBlank( message = "配置值不能为空") String configValue, @Schema(description = "配置状态", example = "ENABLED") String status) {} diff --git a/src/main/java/com/labelsys/backend/dto/response/AgentConfigInfoResponse.java b/src/main/java/com/labelsys/backend/dto/response/AgentConfigInfoResponse.java new file mode 100644 index 0000000..bec04bf --- /dev/null +++ b/src/main/java/com/labelsys/backend/dto/response/AgentConfigInfoResponse.java @@ -0,0 +1,57 @@ +package com.labelsys.backend.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Builder; +import lombok.Data; + +@Schema(description = "Agent配置信息响应") +@Data +@Builder +public class AgentConfigInfoResponse { + + @Schema(description = "配置ID") + private Long id; + + @Schema(description = "Agent类型。支持的Agent类型:RedTeamAgent(红黑对抗)、AnalyzerAgent(分析)、IndustryClassifierAgent(行业识别)、HallucinationDetectorAgent(幻觉检测)、ReviewerAgent(审查)、regenerator(再次生成)") + private String agentType; + + @Schema(description = "模型配置信息") + private ModelConfigInfo modelConfig; + + @Schema(description = "Prompt配置信息") + private PromptConfigInfo promptConfig; + + @Schema(description = "创建时间") + private java.time.LocalDateTime createdAt; + + @Data + @Builder + @Schema(description = "模型配置信息") + public static class ModelConfigInfo { + @Schema(description = "模型配置ID") + private Long id; + + @Schema(description = "模型名称") + private String modelName; + + @Schema(description = "模型URL") + private String modelUrl; + + @Schema(description = "API密钥") + private String apiKey; + + @Schema(description = "LLM类型") + private String llmType; + } + + @Data + @Builder + @Schema(description = "Prompt配置信息") + public static class PromptConfigInfo { + @Schema(description = "Prompt配置ID") + private Long id; + + @Schema(description = "Prompt文本") + private String promptText; + } +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/dto/response/AgentConfigListResponse.java b/src/main/java/com/labelsys/backend/dto/response/AgentConfigListResponse.java new file mode 100644 index 0000000..1fb864e --- /dev/null +++ b/src/main/java/com/labelsys/backend/dto/response/AgentConfigListResponse.java @@ -0,0 +1,23 @@ +package com.labelsys.backend.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Builder; +import lombok.Data; + +import java.util.List; +import java.util.Map; + +@Schema(description = "Agent配置列表响应:返回结果agentConfigs为公司Agent配置Map,key为AgentType枚举值(固定六个Agent类型key), 值为Agent配置对象") +@Data +@Builder +public class AgentConfigListResponse { + + @Schema(description = "Agent配置映射,key为AgentType枚举类型code,value为Agent配置信息。AgentType类型code:RedTeamAgent(红黑对抗)、AnalyzerAgent(分析)、IndustryClassifierAgent(行业识别)、HallucinationDetectorAgent(幻觉检测)、ReviewerAgent(审查)、regenerator(再次生成)") + private Map agentConfigs; + + @Schema(description = "公司所有Prompt配置列表") + private List promptConfigs; + + @Schema(description = "公司所有Model配置列表") + private List modelConfigs; +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/dto/response/SysConfigResponse.java b/src/main/java/com/labelsys/backend/dto/response/SysConfigResponse.java index f0ff9c8..7835e57 100644 --- a/src/main/java/com/labelsys/backend/dto/response/SysConfigResponse.java +++ b/src/main/java/com/labelsys/backend/dto/response/SysConfigResponse.java @@ -1,14 +1,17 @@ package com.labelsys.backend.dto.response; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotBlank; + import java.time.LocalDateTime; @Schema(description = "系统配置响应") public record SysConfigResponse( @Schema(description = "配置ID", example = "191000000000000501") Long id, - @Schema(description = "配置类型", example = "MODEL") String configType, + @Schema(description = "配置类型: MODEL-大模型配置、PROMPT-提示词配置、SYSTEM-其他配置项", example = "MODEL") String configType, @Schema(description = "配置名称", example = "qwen-plus-extract") String configName, - @Schema(description = "配置值", example = "{\"modelName\":\"qwen-plus\",\"modelUrl\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\",\"apiKey\":\"sk-demo1234\"}") String configValue, + @Schema(description = "配置值, 当configType为model时候,configValue为大模型配置的json字符串,configType为其他值时,configValue为普通字符串", + example = "{\"modelName\":\"qwen-plus\",\"modelUrl\":\"https://dashscope.aliyuncs.com/compatible-mode/v1\",\"apiKey\":\"sk-demo1234\"}") String configValue, @Schema(description = "配置状态", example = "ENABLED") String status, @Schema(description = "创建人ID", example = "191000000000000021") Long creatorId, @Schema(description = "创建时间", example = "2026-04-27T09:50:00") LocalDateTime createdAt, diff --git a/src/main/java/com/labelsys/backend/entity/AnnotationAgentConfig.java b/src/main/java/com/labelsys/backend/entity/AnnotationAgentConfig.java new file mode 100644 index 0000000..85c0d99 --- /dev/null +++ b/src/main/java/com/labelsys/backend/entity/AnnotationAgentConfig.java @@ -0,0 +1,160 @@ +package com.labelsys.backend.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler; +import com.fasterxml.jackson.databind.JsonNode; +import com.labelsys.backend.util.SM4Util; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.time.LocalDateTime; + +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +@TableName("annotation_agent_configs") +public class AnnotationAgentConfig { + @TableId(type = IdType.AUTO) + private Long id; + + @TableField("company_id") + private Long companyId; // 所属公司ID,来源于annotation_task.company_id + + @TableField("agent_type") + private String agentType; // Agent角色: extract(抽取/Analyzer+Regenerator) / verify(校验/Reviewer) / industry(行业识别) / hallucination(幻觉检测) + + @TableField("model_config_id") + private Long modelConfigId; // 模型配置来源ID,外键引用sys_config.id,任务启动时从sys_config(config_type=MODEL, status=ENABLED)获取 + + @TableField("model_name") + private String modelName; // 模型名称快照,如: qwen-max / glm-4 / claude-3-opus + + @TableField("model_url") + private String modelUrl; // 模型调用地址快照 + + @TableField("model_api_key") + private String modelApiKey; // 模型调用密钥快照(需要加密存储) + + @TableField("prompt_config_id") + private Long promptConfigId; // Prompt配置来源ID,外键引用sys_config.id,任务启动时从sys_config(config_type=PROMPT, status=ENABLED)获取 + + @TableField("prompt_text") + private String promptText; // Prompt文本快照,任务执行期间实际使用的提示词 + + @TableField(value = "config_snapshot", typeHandler = JacksonTypeHandler.class) + private JsonNode configSnapshot; // 完整配置快照JSON,包含temperature、max_tokens等运行时参数 + + @TableField("created_at") + private LocalDateTime createdAt; // 快照创建时间 + + @TableField("llm_type") + private String llmType; // LLM类型,如:text, audio, video, image等,默认为'text' + + // Getter和Setter方法 + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public Long getCompanyId() { + return companyId; + } + + public void setCompanyId(Long companyId) { + this.companyId = companyId; + } + + public String getAgentType() { + return agentType; + } + + public void setAgentType(String agentType) { + this.agentType = agentType; + } + + public Long getModelConfigId() { + return modelConfigId; + } + + public void setModelConfigId(Long modelConfigId) { + this.modelConfigId = modelConfigId; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(String modelName) { + this.modelName = modelName; + } + + public String getModelUrl() { + return modelUrl; + } + + public void setModelUrl(String modelUrl) { + this.modelUrl = modelUrl; + } + + public String getModelApiKey() { + if (this.modelApiKey != null) { + return SM4Util.decryptSafe(this.modelApiKey); // 解密返回 + } + return null; + } + + public void setModelApiKey(String modelApiKey) { + if (modelApiKey != null) { + this.modelApiKey = SM4Util.encrypt(modelApiKey); // 加密存储 + } + } + + public Long getPromptConfigId() { + return promptConfigId; + } + + public void setPromptConfigId(Long promptConfigId) { + this.promptConfigId = promptConfigId; + } + + public String getPromptText() { + return promptText; + } + + public void setPromptText(String promptText) { + this.promptText = promptText; + } + + public JsonNode getConfigSnapshot() { + return configSnapshot; + } + + public void setConfigSnapshot(JsonNode configSnapshot) { + this.configSnapshot = configSnapshot; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public String getLlmType() { + return llmType; + } + + public void setLlmType(String llmType) { + this.llmType = llmType; + } +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/enums/AgentType.java b/src/main/java/com/labelsys/backend/enums/AgentType.java new file mode 100644 index 0000000..0e6883f --- /dev/null +++ b/src/main/java/com/labelsys/backend/enums/AgentType.java @@ -0,0 +1,43 @@ +package com.labelsys.backend.enums; + +import io.swagger.v3.oas.annotations.media.Schema; + +@Schema(description = "Agent类型枚举") +public enum AgentType { + @Schema(description = "红黑对抗") + RED_TEAM("RedTeamAgent"), + + @Schema(description = "分析") + ANALYZER("AnalyzerAgent"), + + @Schema(description = "行业识别") + INDUSTRY_CLASSIFIER("IndustryClassifierAgent"), + + @Schema(description = "幻觉检测") + HALLUCINATION_DETECTOR("HallucinationDetectorAgent"), + + @Schema(description = "审查") + REVIEWER("ReviewerAgent"), + + @Schema(description = "再次生成") + REGENERATOR("regenerator"); + + private final String code; + + AgentType(String code) { + this.code = code; + } + + public String getCode() { + return code; + } + + public static boolean isValid(String value) { + for (AgentType type : AgentType.values()) { + if (type.name().equalsIgnoreCase(value) || type.getCode().equals(value)) { + return true; + } + } + return false; + } +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/enums/MultiModalType.java b/src/main/java/com/labelsys/backend/enums/MultiModalType.java new file mode 100644 index 0000000..e6096d7 --- /dev/null +++ b/src/main/java/com/labelsys/backend/enums/MultiModalType.java @@ -0,0 +1,37 @@ +package com.labelsys.backend.enums; + +import io.swagger.v3.oas.annotations.media.Schema; + +@Schema(description = "多模态类型枚举") +public enum MultiModalType { + @Schema(description = "文本") + TEXT("text"), + + @Schema(description = "音频") + AUDIO("audio"), + + @Schema(description = "视频") + VIDEO("video"), + + @Schema(description = "图片") + IMAGE("image"); + + private final String code; + + MultiModalType(String code) { + this.code = code; + } + + public String getCode() { + return code; + } + + public static boolean isValid(String value) { + for (MultiModalType type : MultiModalType.values()) { + if (type.name().equalsIgnoreCase(value) || type.getCode().equals(value)) { + return true; + } + } + return false; + } +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/mapper/AnnotationAgentConfigMapper.java b/src/main/java/com/labelsys/backend/mapper/AnnotationAgentConfigMapper.java new file mode 100644 index 0000000..d662f16 --- /dev/null +++ b/src/main/java/com/labelsys/backend/mapper/AnnotationAgentConfigMapper.java @@ -0,0 +1,11 @@ +package com.labelsys.backend.mapper; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import com.labelsys.backend.entity.AnnotationAgentConfig; +import org.apache.ibatis.annotations.Mapper; + +@Mapper +public interface AnnotationAgentConfigMapper extends BaseMapper { + // 基于BaseMapper,已提供基本的CRUD操作 + // 如需特殊查询,可在此添加自定义方法 +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/service/AnnotationAgentConfigService.java b/src/main/java/com/labelsys/backend/service/AnnotationAgentConfigService.java new file mode 100644 index 0000000..764a61c --- /dev/null +++ b/src/main/java/com/labelsys/backend/service/AnnotationAgentConfigService.java @@ -0,0 +1,299 @@ +package com.labelsys.backend.service; + +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.baomidou.mybatisplus.core.toolkit.Wrappers; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.labelsys.backend.common.ResultCode; +import com.labelsys.backend.common.exception.BusinessException; +import com.labelsys.backend.context.LoginUser; +import com.labelsys.backend.dto.LlmConfigModel; +import com.labelsys.backend.dto.request.SaveAgentConfigRequest; +import com.labelsys.backend.dto.response.AgentConfigInfoResponse; +import com.labelsys.backend.dto.response.AgentConfigListResponse; +import com.labelsys.backend.entity.AnnotationAgentConfig; +import com.labelsys.backend.entity.SysConfig; +import com.labelsys.backend.enums.AgentType; +import com.labelsys.backend.mapper.AnnotationAgentConfigMapper; +import com.labelsys.backend.util.SM4Util; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Slf4j +@Service +@Transactional +@RequiredArgsConstructor +public class AnnotationAgentConfigService { + + private final AnnotationAgentConfigMapper agentConfigMapper; + + private final SysConfigService sysConfigService; + + // 保存多个Agent配置 + // 保存多个Agent配置 + public AgentConfigListResponse saveAgentConfigs(LoginUser user, SaveAgentConfigRequest request) { + // 遍历所有Agent配置项 + for (Map.Entry entry : request.getAgentConfigs().entrySet()) { + String agentType = entry.getKey(); + SaveAgentConfigRequest.AgentConfigItem configItem = entry.getValue(); + + // 验证Agent类型是否有效 + if (!AgentType.isValid(agentType)) { + throw new BusinessException(ResultCode.BAD_REQUEST, "无效的Agent类型: " + agentType); + } + + // 验证模型配置是否存在且属于该公司 + SysConfig modelConfig = sysConfigService.getById(configItem.getModelConfigId()); + if (modelConfig == null || modelConfig.getCompanyId() == null || !modelConfig.getCompanyId() + .equals(user.companyId())) { + throw new BusinessException(ResultCode.BAD_REQUEST, "模型配置不存在或不属于当前公司"); + } + + // 验证提示词配置是否存在且属于该公司(如果提供了的话) + if (configItem.getPromptConfigId() != null) { + SysConfig promptConfig = sysConfigService.getById(configItem.getPromptConfigId()); + if (promptConfig == null || promptConfig.getCompanyId() == null || !promptConfig.getCompanyId() + .equals(user.companyId())) { + throw new BusinessException(ResultCode.BAD_REQUEST, "提示词配置不存在或不属于当前公司"); + } + } + + // 从模型配置中提取API密钥 + String apiKey = extractApiKeyFromConfig(modelConfig.getConfigValue()); + + // 检查是否已存在配置 + LambdaQueryWrapper wrapper = Wrappers.lambdaQuery() + .eq(AnnotationAgentConfig::getCompanyId, user.companyId()) + .eq(AnnotationAgentConfig::getAgentType, agentType); + + AnnotationAgentConfig existing = agentConfigMapper.selectOne(wrapper); + if (existing != null) { + // 更新现有配置 + existing.setModelConfigId(configItem.getModelConfigId()); + existing.setPromptConfigId(configItem.getPromptConfigId()); + // 从模型配置中获取其他字段 + existing.setModelName(modelConfig.getConfigName()); + existing.setModelUrl(extractModelUrlFromConfig(modelConfig.getConfigValue())); + existing.setModelApiKey(apiKey); // 保存API密钥 + existing.setLlmType(extractLlmTypeFromConfig(modelConfig.getConfigValue())); + existing.setCreatedAt(LocalDateTime.now()); + + agentConfigMapper.updateById(existing); + } else { + // 创建新配置 + AnnotationAgentConfig config = AnnotationAgentConfig.builder() + .companyId(user.companyId()) + .agentType(agentType) + .modelConfigId(configItem.getModelConfigId()) + .modelName(modelConfig.getConfigName()) + .modelUrl(extractModelUrlFromConfig(modelConfig.getConfigValue())) + .modelApiKey(apiKey) // 保存API密钥 + .promptConfigId(configItem.getPromptConfigId()) + .promptText(getPromptText(configItem.getPromptConfigId())) + .llmType(extractLlmTypeFromConfig(modelConfig.getConfigValue())) + .createdAt(LocalDateTime.now()) + .build(); + + agentConfigMapper.insert(config); + } + } + + // 返回更新后的完整配置列表 + return getAgentConfigsForCompany(user.companyId()); + } + + // 从模型配置值中提取API密钥 + private String extractApiKeyFromConfig(String configValue) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel llmConfig = objectMapper.readValue(configValue, LlmConfigModel.class); + return llmConfig != null ? llmConfig.getApiKey() : null; + } catch (Exception e) { + return null; + } + } + + + // 从模型配置值中提取模型URL + private String extractModelUrlFromConfig(String configValue) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel llmConfig = objectMapper.readValue(configValue, LlmConfigModel.class); + return llmConfig != null ? llmConfig.getModelUrl() : null; + } catch (Exception e) { + return null; + } + } + + // 从模型配置值中提取LLM类型 + private String extractLlmTypeFromConfig(String configValue) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel llmConfig = objectMapper.readValue(configValue, LlmConfigModel.class); + return llmConfig != null ? llmConfig.getLlmType() : null; + } catch (Exception e) { + return null; + } + } + + // 获取提示词文本 + private String getPromptText(Long promptConfigId) { + if (promptConfigId == null) { + return null; + } + SysConfig promptConfig = sysConfigService.getById(promptConfigId); + return promptConfig != null ? promptConfig.getConfigValue() : null; + } + + // 获取特定Agent类型的配置实体 + public AnnotationAgentConfig getAgentConfigEntity(Long companyId, String agentType) { + LambdaQueryWrapper wrapper = Wrappers.lambdaQuery() + .eq(AnnotationAgentConfig::getCompanyId, companyId) + .eq(AnnotationAgentConfig::getAgentType, agentType); + + return agentConfigMapper.selectOne(wrapper); + } + + // 获取公司所有Agent配置的完整信息 + public AgentConfigListResponse getAgentConfigsForCompany(Long companyId) { + // 获取所有固定Agent类型 + Map agentConfigs = new HashMap<>(); + + // 获取该公司所有的Prompt配置 + List promptConfigs = getCompanyPromptConfigs(companyId); + + // 获取该公司所有的Model配置 + List modelConfigs = getCompanyModelConfigs(companyId); + + // 遍历所有Agent类型 + for (AgentType agentType : AgentType.values()) { + String agentTypeCode = agentType.getCode(); + + // 查询该Agent类型的配置 + AnnotationAgentConfig agentConfig = getAgentConfigByType(companyId, agentTypeCode); + + if (agentConfig != null) { + // 如果存在配置,构建完整配置信息 + AgentConfigInfoResponse.AgentConfigInfoResponseBuilder builder = AgentConfigInfoResponse.builder() + .id(agentConfig.getId()) + .agentType(agentConfig.getAgentType()) + .createdAt(agentConfig.getCreatedAt()); + + // 设置模型配置信息 + if (agentConfig.getModelConfigId() != null) { + SysConfig modelConfig = sysConfigService.getById(agentConfig.getModelConfigId()); + if (modelConfig != null) { + AgentConfigInfoResponse.ModelConfigInfo modelConfigInfo = AgentConfigInfoResponse.ModelConfigInfo.builder() + .id(modelConfig.getId()) + .modelName(modelConfig.getConfigName()) + .build(); + + // 解析模型配置值 + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel llmConfig = objectMapper.readValue(modelConfig.getConfigValue(), + LlmConfigModel.class); + if (llmConfig != null) { + modelConfigInfo.setModelUrl(llmConfig.getModelUrl()); + modelConfigInfo.setApiKey(SM4Util.decryptSafe(llmConfig.getApiKey())); + modelConfigInfo.setLlmType(llmConfig.getLlmType()); + } + } catch (Exception e) { + // 解析失败,忽略 + } + + builder.modelConfig(modelConfigInfo); + } + } + + // 设置Prompt配置信息 + if (agentConfig.getPromptConfigId() != null) { + SysConfig promptConfig = sysConfigService.getById(agentConfig.getPromptConfigId()); + if (promptConfig != null) { + AgentConfigInfoResponse.PromptConfigInfo promptConfigInfo = AgentConfigInfoResponse.PromptConfigInfo.builder() + .id(promptConfig.getId()) + .promptText(promptConfig.getConfigValue()) + .build(); + builder.promptConfig(promptConfigInfo); + } + } + + agentConfigs.put(agentTypeCode, builder.build()); + } else { + // 如果没有配置,返回空的配置信息 + agentConfigs.put(agentTypeCode, AgentConfigInfoResponse.builder() + .agentType(agentTypeCode) + .build()); + } + } + + return AgentConfigListResponse.builder() + .agentConfigs(agentConfigs) + .promptConfigs(promptConfigs) + .modelConfigs(modelConfigs) + .build(); + } + + // 获取公司所有Prompt配置 + private List getCompanyPromptConfigs(Long companyId) { + List configs = sysConfigService.getCompanyConfigsByType(companyId, "PROMPT"); + return configs.stream().map(config -> AgentConfigInfoResponse.PromptConfigInfo.builder() + .id(config.getId()) + .promptText(config.getConfigValue()) + .build()).collect(Collectors.toList()); + } + + // 获取公司所有Model配置 + private List getCompanyModelConfigs(Long companyId) { + List configs = sysConfigService.getCompanyConfigsByType(companyId, "MODEL"); + return configs.stream().map(config -> { + AgentConfigInfoResponse.ModelConfigInfo.ModelConfigInfoBuilder builder = + AgentConfigInfoResponse.ModelConfigInfo.builder() + .id(config.getId()) + .modelName(config.getConfigName()); + + // 解析模型配置值 + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel llmConfig = objectMapper.readValue(config.getConfigValue(), LlmConfigModel.class); + if (llmConfig != null) { + builder.modelUrl(llmConfig.getModelUrl()) + .apiKey(SM4Util.decryptSafe(llmConfig.getApiKey())) + .llmType(llmConfig.getLlmType()); + } + } catch (Exception e) { + // 解析失败,忽略 + } + + return builder.build(); + }).collect(Collectors.toList()); + } + + // 根据类型获取Agent配置 + private AnnotationAgentConfig getAgentConfigByType(Long companyId, String agentType) { + LambdaQueryWrapper wrapper = Wrappers.lambdaQuery() + .eq(AnnotationAgentConfig::getCompanyId, companyId) + .eq(AnnotationAgentConfig::getAgentType, agentType); + + return agentConfigMapper.selectOne(wrapper); + } + + // private AgentConfigDetailResponse toDetailResponse(AnnotationAgentConfig config) { + // return AgentConfigDetailResponse.builder() + // .id(config.getId()) + // .agentType(config.getAgentType()) + // .llmType(config.getLlmType()) + // .modelName(config.getModelName()) + // .modelUrl(config.getModelUrl()) + // .promptText(config.getPromptText()) + // .configSnapshot(config.getConfigSnapshot()) + // .build(); + // } +} \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/service/AnnotationResultService.java b/src/main/java/com/labelsys/backend/service/AnnotationResultService.java index 522aaca..c420df4 100644 --- a/src/main/java/com/labelsys/backend/service/AnnotationResultService.java +++ b/src/main/java/com/labelsys/backend/service/AnnotationResultService.java @@ -266,7 +266,8 @@ public class AnnotationResultService { List updatedQaRecords = qaContent.records().stream() .map(record -> { String mergedAnswer = request.mergedAnswers().get(record.id()); - String reviewComment = request.reviewComments() != null ? request.reviewComments().get(record.id()) : null; + String reviewComment = + request.reviewComments() != null ? request.reviewComments().get(record.id()) : null; if (mergedAnswer != null || reviewComment != null) { return new QaContent.QaRecord( record.id(), @@ -364,7 +365,7 @@ public class AnnotationResultService { return objectMapper.readValue(jsonContent, new TypeReference() { }); } catch (Exception e) { - log.warn("Failed to load qa content, returning empty content. resultId={}, filePath={}, error={}", + log.warn("Failed to load qa content, returning empty content. resultId={}, filePath={}, error={}", result.getId(), result.getQaContentFilePath(), e.getMessage()); return new QaContent(null, null, List.of(), null); } @@ -384,7 +385,7 @@ public class AnnotationResultService { return objectMapper.readValue(jsonContent, new TypeReference() { }); } catch (Exception e) { - log.warn("Failed to load diff summary, returning empty content. resultId={}, filePath={}, error={}", + log.warn("Failed to load diff summary, returning empty content. resultId={}, filePath={}, error={}", result.getId(), result.getDiffSummaryFilePath(), e.getMessage()); return new DiffContent(null, null, List.of(), null); } @@ -487,9 +488,9 @@ public class AnnotationResultService { List records, Metadata metadata ) { - private record QaRecord(String id, Long batchId, String question, String answer, - Boolean requiresReview, SourceSegments sourceSegments, - String questionCategory, Scores scores, String reviewComment) { + private record QaRecord(String id, Long batchId, String question, String answer, + Boolean requiresReview, SourceSegments sourceSegments, + String questionCategory, Scores scores, String reviewComment) { } private record SourceSegments(String segment, Integer chunkIndex, String chunkTitle, String chunkContent) { diff --git a/src/main/java/com/labelsys/backend/service/SysConfigService.java b/src/main/java/com/labelsys/backend/service/SysConfigService.java index b8e17ef..80baddc 100644 --- a/src/main/java/com/labelsys/backend/service/SysConfigService.java +++ b/src/main/java/com/labelsys/backend/service/SysConfigService.java @@ -2,9 +2,12 @@ package com.labelsys.backend.service; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.labelsys.backend.common.ResultCode; import com.labelsys.backend.common.exception.BusinessException; import com.labelsys.backend.context.LoginUser; +import com.labelsys.backend.dto.LlmConfigModel; import com.labelsys.backend.dto.common.PageResult; import com.labelsys.backend.dto.request.SaveSysConfigRequest; import com.labelsys.backend.dto.request.SysConfigPageQuery; @@ -12,9 +15,9 @@ import com.labelsys.backend.dto.request.UpdateSysConfigRequest; import com.labelsys.backend.dto.response.SysConfigResponse; import com.labelsys.backend.entity.SysConfig; import com.labelsys.backend.enums.ConfigType; -import com.labelsys.backend.enums.UserRole; import com.labelsys.backend.mapper.SysConfigMapper; import com.labelsys.backend.util.IdGenerator; +import com.labelsys.backend.util.SM4Util; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -28,8 +31,8 @@ import java.util.List; @RequiredArgsConstructor public class SysConfigService { - private final SysConfigMapper sysConfigMapper; - private final DataPermissionService dataPermissionService; + private final SysConfigMapper sysConfigMapper; + //private final DataPermissionService dataPermissionService; @Transactional public SysConfig saveConfig(LoginUser currentUser, SaveSysConfigRequest request) { @@ -40,12 +43,19 @@ public class SysConfigService { if (existing != null) { throw new BusinessException(ResultCode.CONFLICT, "配置名称已存在"); } + + String processedConfigValue = request.configValue(); + // 如果是model类型,对apiKey进行加密处理 + if (ConfigType.MODEL.name().equalsIgnoreCase(request.configType())) { + processedConfigValue = processModelConfigValue(processedConfigValue, true); + } + SysConfig config = SysConfig.builder() .id(IdGenerator.nextId()) .companyId(currentUser.companyId()) .configType(request.configType()) .configName(request.configName()) - .configValue(request.configValue()) + .configValue(processedConfigValue) .status(request.status()) .creatorId(currentUser.userId()) .creatorRole(currentUser.role().name()) @@ -77,7 +87,12 @@ public class SysConfigService { existing.setConfigType(request.configType()); } if (StringUtils.hasText(request.configValue())) { - existing.setConfigValue(request.configValue()); + // 如果是model类型,对apiKey进行加密处理 + if (ConfigType.MODEL.name().equalsIgnoreCase(request.configType())) { + existing.setConfigValue(processModelConfigValue(request.configValue(), true)); + } else { + existing.setConfigValue(request.configValue()); + } } if (StringUtils.hasText(request.status())) { existing.setStatus(request.status()); @@ -95,27 +110,10 @@ public class SysConfigService { } } - public SysConfigResponse getConfig(LoginUser currentUser, Long configId) { - try { - SysConfig config = getConfigEntity(currentUser, configId); - if (!dataPermissionService.canAccessCreator(currentUser, config.getCreatorId(), - UserRole.valueOf(config.getCreatorRole()))) { - throw new BusinessException(ResultCode.FORBIDDEN, "无权访问配置"); - } - return toResponse(config); - } catch (BusinessException e) { - throw e; - } catch (Exception e) { - log.error("getConfig failed, companyId={}, userId={}, configId={}, error={}", - currentUser.companyId(), currentUser.userId(), configId, e.getMessage(), e); - throw e; - } - } - public PageResult pageConfigs(LoginUser currentUser, SysConfigPageQuery query) { try { - List allowedRoles = dataPermissionService.getAllowedRoles(currentUser); - boolean shouldFilterByUserId = dataPermissionService.shouldFilterByUserId(currentUser); + // List allowedRoles = dataPermissionService.getAllowedRoles(currentUser); + // boolean shouldFilterByUserId = dataPermissionService.shouldFilterByUserId(currentUser); LambdaQueryWrapper wrapper = new LambdaQueryWrapper() .eq(SysConfig::getCompanyId, currentUser.companyId()) @@ -123,11 +121,11 @@ public class SysConfigService { .eq(StringUtils.hasText(query.status()), SysConfig::getStatus, query.status()) .like(StringUtils.hasText(query.configName()), SysConfig::getConfigName, query.configName()); - if (shouldFilterByUserId) { - wrapper.eq(SysConfig::getCreatorId, currentUser.userId()); - } else if (!allowedRoles.isEmpty()) { - wrapper.in(SysConfig::getCreatorRole, allowedRoles); - } + // if (shouldFilterByUserId) { + // wrapper.eq(SysConfig::getCreatorId, currentUser.userId()); + // } else if (!allowedRoles.isEmpty()) { + // wrapper.in(SysConfig::getCreatorRole, allowedRoles); + // } wrapper.orderByDesc(SysConfig::getCreatedAt); @@ -173,4 +171,84 @@ public class SysConfigService { throw new BusinessException(ResultCode.BAD_REQUEST, "配置类型非法"); } } + + /** + * 处理模型配置值,加密或解密API密钥 + * + * @param configValue 配置值 + * @param encrypt true表示加密,false表示解密 + * @return 处理后的配置值 + */ + private String processModelConfigValue(String configValue, boolean encrypt) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel configModel = objectMapper.readValue(configValue, LlmConfigModel.class); + + if (configModel.getApiKey() != null && !configModel.getApiKey().isEmpty()) { + if (encrypt) { + // 加密API密钥 + configModel.setApiKey(SM4Util.encrypt(configModel.getApiKey())); + } else { + // 解密API密钥 + configModel.setApiKey(SM4Util.decryptSafe(configModel.getApiKey())); + } + } + + return objectMapper.writeValueAsString(configModel); + } catch (JsonProcessingException e) { + log.error("处理模型配置值失败: {}", e.getMessage(), e); + throw new BusinessException(ResultCode.BAD_REQUEST, "模型配置格式错误"); + } + } + + /** + * 根据ID获取配置实体 + */ + public SysConfig getById(Long configId) { + return sysConfigMapper.selectById(configId); + } + + /** + * 获取配置详情(根据配置ID查询数据库判断类型并返回SysConfigResponse格式) + */ + public SysConfigResponse getConfigDetail(LoginUser currentUser, Long configId) { + SysConfig config = getConfigEntity(currentUser, configId); + + // 如果是模型配置,我们需要在返回前处理API密钥解密 + if (ConfigType.MODEL.name().equalsIgnoreCase(config.getConfigType())) { + if (config.getConfigValue() != null) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + LlmConfigModel model = objectMapper.readValue(config.getConfigValue(), LlmConfigModel.class); + if (model != null && model.getApiKey() != null) { + // 解密API密钥并更新配置值 + String decryptedApiKey = SM4Util.decryptSafe(model.getApiKey()); + model.setApiKey(decryptedApiKey); + + // 将更新后的对象转回JSON字符串,但这次API密钥是解密的 + String updatedConfigValue = objectMapper.writeValueAsString(model); + config.setConfigValue(updatedConfigValue); + } + } catch (JsonProcessingException e) { + log.error("解析模型配置失败: {}", e.getMessage(), e); + throw new BusinessException(ResultCode.BAD_REQUEST, "模型配置格式错误"); + } + } + } + + // 总是返回SysConfigResponse对象 + return toResponse(config); + } + + /** + * 获取公司特定类型的配置实体列表 + */ + public List getCompanyConfigsByType(Long companyId, String configType) { + LambdaQueryWrapper wrapper = + new LambdaQueryWrapper() + .eq(SysConfig::getCompanyId, companyId) + .eq(SysConfig::getConfigType, configType); + + return sysConfigMapper.selectList(wrapper); + } } \ No newline at end of file diff --git a/src/main/java/com/labelsys/backend/util/SM4Util.java b/src/main/java/com/labelsys/backend/util/SM4Util.java index 1c6d144..fc057cd 100644 --- a/src/main/java/com/labelsys/backend/util/SM4Util.java +++ b/src/main/java/com/labelsys/backend/util/SM4Util.java @@ -5,24 +5,95 @@ import org.bouncycastle.crypto.modes.CBCBlockCipher; import org.bouncycastle.crypto.paddings.PaddedBufferedBlockCipher; import org.bouncycastle.crypto.params.KeyParameter; import org.bouncycastle.crypto.params.ParametersWithIV; +import org.bouncycastle.jce.provider.BouncyCastleProvider; import java.nio.charset.StandardCharsets; -import java.util.Base64; +import java.security.Security; +import java.util.Arrays; +/** + * SM4 国密加解密工具类 + * SM4-CBC 模式,密钥/IV 从环境变量读取(hex编码,32字符) + */ public class SM4Util { - private static final int KEY_LENGTH = 16; - private static final int IV_LENGTH = 16; + private static final String ALGORITHM = "SM4"; + private static final String TRANSFORMATION = "SM4/CBC/NoPadding"; + private static final int BLOCK_SIZE = 16; + private static final String SM4_SECRET_KEY = "a5c9b98d8042ff99a5befedd1bcdc78a"; + private static final String SM4_IV = "83e3d80962fea60369fb7ebaac9b2285"; + + static { + Security.addProvider(new BouncyCastleProvider()); + } + + // ---- 密钥 / IV ---- + + private static byte[] getKey() { + //String hex = System.getenv("SM4_SECRET_KEY"); + if (SM4_SECRET_KEY == null || SM4_SECRET_KEY.isEmpty()) + throw new SM4CryptoException("环境变量 SM4_SECRET_KEY 未配置"); + byte[] key = hexToBytes(SM4_SECRET_KEY); + if (key.length != 16) + throw new SM4CryptoException("SM4 密钥长度必须为16字节,当前为 " + key.length + " 字节"); + return key; + } + + private static byte[] getIV() { + //String hex = System.getenv("SM4_IV"); + if (SM4_IV != null && !SM4_IV.isEmpty()) { + byte[] iv = hexToBytes(SM4_IV); + if (iv.length != 16) + throw new SM4CryptoException("SM4 IV 长度必须为16字节,当前为 " + iv.length + " 字节"); + return iv; + } + return new byte[16]; // 默认全零 IV + } + + // ---- PKCS7 填充 ---- + + private static byte[] pad(byte[] data) { + int n = BLOCK_SIZE - (data.length % BLOCK_SIZE); + byte[] out = Arrays.copyOf(data, data.length + n); + Arrays.fill(out, data.length, out.length, (byte) n); + return out; + } + + private static byte[] unpad(byte[] data) { + if (data.length == 0) + throw new SM4CryptoException("数据为空,无法去填充"); + int n = data[data.length - 1] & 0xFF; + if (n < 1 || n > 16) + throw new SM4CryptoException("无效的填充长度: " + n); + for (int i = data.length - n; i < data.length; i++) { + if ((data[i] & 0xFF) != n) + throw new SM4CryptoException("填充校验失败"); + } + return Arrays.copyOf(data, data.length - n); + } + + // ---- 加解密方法 ---- + + /** + * SM4-CBC 加密,返回 hex 密文 + */ public static String encrypt(String plainText, String key, String iv) { try { - validateKey(key); - validateIV(iv); + if (key == null || key.length() != 32) { + throw new IllegalArgumentException("SM4密钥长度必须为32位十六进制字符"); + } + if (iv == null || iv.length() != 32) { + throw new IllegalArgumentException("SM4 IV长度必须为32位十六进制字符"); + } - byte[] keyBytes = key.getBytes(StandardCharsets.UTF_8); - byte[] ivBytes = iv.getBytes(StandardCharsets.UTF_8); + byte[] keyBytes = hexToBytes(key); + byte[] ivBytes = hexToBytes(iv); byte[] plainBytes = plainText.getBytes(StandardCharsets.UTF_8); + // 手动进行PKCS7填充,与Python版本保持一致 + byte[] paddedData = pkcs7Pad(plainBytes, 16); + SM4Engine engine = new SM4Engine(); CBCBlockCipher blockCipher = new CBCBlockCipher(engine); PaddedBufferedBlockCipher cipher = new PaddedBufferedBlockCipher(blockCipher); @@ -31,24 +102,33 @@ public class SM4Util { ParametersWithIV parametersWithIV = new ParametersWithIV(keyParameter, ivBytes); cipher.init(true, parametersWithIV); - byte[] encrypted = new byte[cipher.getOutputSize(plainBytes.length)]; - int len = cipher.processBytes(plainBytes, 0, plainBytes.length, encrypted, 0); + byte[] encrypted = new byte[cipher.getOutputSize(paddedData.length)]; + int len = cipher.processBytes(paddedData, 0, paddedData.length, encrypted, 0); len += cipher.doFinal(encrypted, len); - return Base64.getEncoder().encodeToString(encrypted); + byte[] result = new byte[len]; + System.arraycopy(encrypted, 0, result, 0, len); + return bytesToHex(result); } catch (Exception e) { throw new RuntimeException("SM4加密失败", e); } } + /** + * SM4-CBC 解密,传入 hex 密文 + */ public static String decrypt(String cipherText, String key, String iv) { try { - validateKey(key); - validateIV(iv); + if (key == null || key.length() != 32) { + throw new IllegalArgumentException("SM4密钥长度必须为32位十六进制字符"); + } + if (iv == null || iv.length() != 32) { + throw new IllegalArgumentException("SM4 IV长度必须为32位十六进制字符"); + } - byte[] keyBytes = key.getBytes(StandardCharsets.UTF_8); - byte[] ivBytes = iv.getBytes(StandardCharsets.UTF_8); - byte[] cipherBytes = Base64.getDecoder().decode(cipherText); + byte[] keyBytes = hexToBytes(key); + byte[] ivBytes = hexToBytes(iv); + byte[] cipherBytes = hexToBytes(cipherText); SM4Engine engine = new SM4Engine(); CBCBlockCipher blockCipher = new CBCBlockCipher(engine); @@ -62,40 +142,111 @@ public class SM4Util { int len = cipher.processBytes(cipherBytes, 0, cipherBytes.length, decrypted, 0); len += cipher.doFinal(decrypted, len); - return new String(decrypted, 0, len, StandardCharsets.UTF_8); + // 手动去PKCS7填充,与Python版本保持一致 + byte[] unpadded = pkcs7Unpad(Arrays.copyOf(decrypted, len)); + + return new String(unpadded, StandardCharsets.UTF_8); } catch (Exception e) { throw new RuntimeException("SM4解密失败", e); } } - public static String encrypt(String plainText, String key) { - return encrypt(plainText, key, key.substring(0, IV_LENGTH)); + /** + * 使用默认密钥和IV进行加密 + */ + public static String encrypt(String plainText) { + return encrypt(plainText, SM4_SECRET_KEY, SM4_IV); } - public static String decrypt(String cipherText, String key) { - return decrypt(cipherText, key, key.substring(0, IV_LENGTH)); + /** + * 使用默认密钥和IV进行解密 + */ + public static String decrypt(String cipherText) { + return decrypt(cipherText, SM4_SECRET_KEY, SM4_IV); } - private static void validateKey(String key) { - if (key == null || key.length() != KEY_LENGTH) { - throw new IllegalArgumentException("SM4密钥长度必须为16位"); + /** + * 解密,非有效密文时原样返回(兼容未加密旧数据) + */ + public static String decryptSafe(String value) { + try { + return decrypt(value); + } catch (Exception e) { + return value; } } - private static void validateIV(String iv) { - if (iv == null || iv.length() != IV_LENGTH) { - throw new IllegalArgumentException("SM4 IV长度必须为16位"); + // ---- PKCS7 填充/去填充方法 ---- + + private static byte[] pkcs7Pad(byte[] data, int blockSize) { + int paddingLen = blockSize - (data.length % blockSize); + byte[] paddedData = new byte[data.length + paddingLen]; + System.arraycopy(data, 0, paddedData, 0, data.length); + for (int i = 0; i < paddingLen; i++) { + paddedData[data.length + i] = (byte) paddingLen; + } + return paddedData; + } + + private static byte[] pkcs7Unpad(byte[] data) { + if (data == null || data.length == 0) { + throw new RuntimeException("数据为空,无法去填充"); + } + int padLen = data[data.length - 1] & 0xFF; // 转换为无符号整数 + if (padLen < 1 || padLen > 16) { + throw new RuntimeException("无效的填充长度: " + padLen); + } + if (data.length < padLen) { + throw new RuntimeException("数据长度小于填充长度"); + } + // 检查填充是否正确 + for (int i = 0; i < padLen; i++) { + if (data[data.length - padLen + i] != (byte) padLen) { + throw new RuntimeException("填充校验失败"); + } + } + byte[] result = new byte[data.length - padLen]; + System.arraycopy(data, 0, result, 0, data.length - padLen); + return result; + } + + // ---- Hex 工具 ---- + + private static String bytesToHex(byte[] b) { + StringBuilder sb = new StringBuilder(b.length * 2); + for (byte v : b) + sb.append(String.format("%02x", v & 0xFF)); + return sb.toString(); + } + + private static byte[] hexToBytes(String hex) { + byte[] out = new byte[hex.length() / 2]; + for (int i = 0; i < out.length; i++) + out[i] = (byte) ((Character.digit(hex.charAt(2 * i), 16) << 4) + + Character.digit(hex.charAt(2 * i + 1), 16)); + return out; + } + + // ---- 异常 ---- + + public static class SM4CryptoException extends RuntimeException { + public SM4CryptoException(String msg) { + super(msg); + } + + public SM4CryptoException(String msg, Throwable cause) { + super(msg, cause); } } static void main() { - String key = "1234567890123456"; - String iv = "abcdefghijklmnop"; + String key = "a5c9b98d8042ff99a5befedd1bcdc78a"; + String iv = "83e3d80962fea60369fb7ebaac9b2285"; - String plainText = "Hello World!"; + String plainText = "extract-api-key-demo"; - String encrypted = SM4Util.encrypt(plainText, key, iv); - String decrypted = SM4Util.decrypt(encrypted, key, iv); + String encrypted = SM4Util.encrypt(plainText); + String decrypted = SM4Util.decrypt(encrypted); System.out.println("原文: " + plainText); System.out.println("密文: " + encrypted);