diff --git a/sqle/api/controller/v1/sql_manage.go b/sqle/api/controller/v1/sql_manage.go index feb1c899cb..0b972440ff 100644 --- a/sqle/api/controller/v1/sql_manage.go +++ b/sqle/api/controller/v1/sql_manage.go @@ -2,10 +2,13 @@ package v1 import ( "context" + "net/http" dmsV1 "github.com/actiontech/dms/pkg/dms-common/api/dms/v1" "github.com/actiontech/sqle/sqle/api/controller" + "github.com/actiontech/sqle/sqle/dms" "github.com/actiontech/sqle/sqle/locale" + "github.com/actiontech/sqle/sqle/model" "github.com/labstack/echo/v4" ) @@ -66,6 +69,61 @@ type AuditResult struct { ExecutionFailed bool `json:"execution_failed"` } +type RuleDiff struct { + Resolved []*AuditResult `json:"resolved"` + New []*AuditResult `json:"new"` + Unchanged []*AuditResult `json:"unchanged"` +} + +type GetSqlManageRemediationResp struct { + controller.BaseRes + Data *SqlManageRemediation `json:"data"` +} + +type GetSqlManageRemediationOverviewReq struct { + InstanceAuditPlanID string `query:"instance_audit_plan_id" json:"instance_audit_plan_id" valid:"required"` + AuditPlanType string `query:"audit_plan_type" json:"audit_plan_type" valid:"required"` +} + +type GetSqlManageRemediationOverviewResp struct { + controller.BaseRes + Data *SqlManageRemediationOverview `json:"data"` +} + +type SqlManageRemediationOverview struct { + ProjectID string `json:"project_id"` + InstanceAuditPlanID string `json:"instance_audit_plan_id"` + AuditPlanType string `json:"audit_plan_type"` + SqlTotalNum uint64 `json:"sql_total_num"` + FirstScore int32 `json:"first_score"` + LatestScore int32 `json:"latest_score"` + ScoreChange int32 `json:"score_change"` + RemediationRate float64 `json:"remediation_rate"` + RemediationStatusCount *RemediationStatusCounter `json:"remediation_status_count"` + FirstAuditMissingNum uint64 `json:"first_audit_missing_num"` +} + +type RemediationStatusCounter struct { + Resolved uint64 `json:"resolved"` + PartiallyFixed uint64 `json:"partially_fixed"` + Unchanged uint64 `json:"unchanged"` + Deteriorated uint64 `json:"deteriorated"` + NewlyDiscovered uint64 `json:"newly_discovered"` +} + +type SqlManageRemediation struct { + Id uint64 `json:"id"` + SqlFingerprint string `json:"sql_fingerprint"` + Sql string `json:"sql"` + FirstAuditResult []*AuditResult `json:"first_audit_result"` + FirstAuditTime string `json:"first_audit_time"` + LatestAuditResult []*AuditResult `json:"latest_audit_result"` + LatestAuditTime string `json:"latest_audit_time"` + RuleDiff *RuleDiff `json:"rule_diff"` + RemediationStatus string `json:"remediation_status" enums:"resolved,partially_fixed,unchanged,deteriorated,newly_discovered"` + FirstAuditMissing bool `json:"first_audit_missing"` +} + type Source struct { SqlSourceType string `json:"sql_source_type"` SqlSourceDesc string `json:"sql_source_desc"` @@ -90,6 +148,7 @@ type Source struct { // @Param filter_last_audit_start_time_to query string false "last audit start time to" // @Param filter_status query string false "status" Enums(unhandled,solved,ignored,manual_audited) // @Param filter_rule_name query string false "rule name" +// @Param filter_remediation_status query string false "remediation status" Enums(resolved,partially_fixed,unchanged,deteriorated,newly_discovered) // @Param filter_db_type query string false "db type" // @Param fuzzy_search_endpoint query string false "fuzzy search endpoint" // @Param fuzzy_search_schema_name query string false "fuzzy search schema name" @@ -103,6 +162,61 @@ func GetSqlManageList(c echo.Context) error { return nil } +// GetSqlManageRemediationOverviewV1 +// @Summary 获取SQL管控整改概览 +// @Description get sql manage remediation overview +// @Tags SqlManage +// @Id GetSqlManageRemediationOverviewV1 +// @Security ApiKeyAuth +// @Param project_name path string true "project name" +// @Param instance_audit_plan_id query string true "instance audit plan id" +// @Param audit_plan_type query string true "audit plan type" +// @Success 200 {object} v1.GetSqlManageRemediationOverviewResp +// @Router /v1/projects/{project_name}/sql_manages/remediation_overview [get] +func GetSqlManageRemediationOverviewV1(c echo.Context) error { + return getSqlManageRemediationOverviewV1(c) +} + +func getSqlManageRemediationOverviewV1(c echo.Context) error { + req := new(GetSqlManageRemediationOverviewReq) + if err := controller.BindAndValidateReq(c, req); err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + projectUID, err := dms.GetProjectUIDByName(c.Request().Context(), c.Param("project_name")) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + records, err := model.GetStorage().GetManagerSQLListByInstanceAuditPlanAndType(req.InstanceAuditPlanID, req.AuditPlanType) + if err != nil { + return controller.JSONBaseErrorReq(c, err) + } + + overview := model.CalculateSqlManageRemediationOverview(projectUID, req.InstanceAuditPlanID, req.AuditPlanType, records) + return c.JSON(http.StatusOK, &GetSqlManageRemediationOverviewResp{ + BaseRes: controller.NewBaseReq(nil), + Data: &SqlManageRemediationOverview{ + ProjectID: overview.ProjectID, + InstanceAuditPlanID: overview.InstanceAuditPlanID, + AuditPlanType: overview.AuditPlanType, + SqlTotalNum: overview.SqlTotalNum, + FirstScore: overview.FirstScore, + LatestScore: overview.LatestScore, + ScoreChange: overview.ScoreChange, + RemediationRate: overview.RemediationRate, + RemediationStatusCount: &RemediationStatusCounter{ + Resolved: overview.RemediationStatusCount.Resolved, + PartiallyFixed: overview.RemediationStatusCount.PartiallyFixed, + Unchanged: overview.RemediationStatusCount.Unchanged, + Deteriorated: overview.RemediationStatusCount.Deteriorated, + NewlyDiscovered: overview.RemediationStatusCount.NewlyDiscovered, + }, + FirstAuditMissingNum: overview.FirstAuditMissingNum, + }, + }) +} + type BatchUpdateSqlManageReq struct { SqlManageIdList []*uint64 `json:"sql_manage_id_list"` Status *string `json:"status" enums:"solved,ignored,manual_audited"` @@ -518,9 +632,9 @@ func GetGlobalSqlManageList(c echo.Context) error { type GetGlobalSqlManageStatisticsReq struct { FilterProjectUid *string `query:"filter_project_uid" json:"filter_project_uid,omitempty"` - FilterInstanceId *string `query:"filter_instance_id" json:"filter_instance_id,omitempty"` - FilterProjectPriority *dmsV1.ProjectPriority `query:"filter_project_priority" json:"filter_project_priority,omitempty" enums:"high,medium,low"` - FilterCurrentStepAssigneeUserId *string `query:"filter_current_step_assignee_user_id" json:"filter_current_step_assignee_user_id,omitempty"` + FilterInstanceId *string `query:"filter_instance_id" json:"filter_instance_id,omitempty"` + FilterProjectPriority *dmsV1.ProjectPriority `query:"filter_project_priority" json:"filter_project_priority,omitempty" enums:"high,medium,low"` + FilterCurrentStepAssigneeUserId *string `query:"filter_current_step_assignee_user_id" json:"filter_current_step_assignee_user_id,omitempty"` } type GetGlobalSqlManageStatisticsResp struct { diff --git a/sqle/api/controller/v1/sql_manager_ce.go b/sqle/api/controller/v1/sql_manager_ce.go index 89d91ce745..56cb96681c 100644 --- a/sqle/api/controller/v1/sql_manager_ce.go +++ b/sqle/api/controller/v1/sql_manager_ce.go @@ -20,6 +20,14 @@ func sendSqlManage(c echo.Context) error { return ErrCommunityEditionNotSupportSqlManage } +func exportSqlManageRemediationV1(c echo.Context) error { + return ErrCommunityEditionNotSupportSqlManage +} + +func exportGlobalSqlManageRemediationV1(c echo.Context) error { + return ErrCommunityEditionNotSupportSqlManage +} + func getSqlManageRuleTips(c echo.Context) error { return ErrCommunityEditionNotSupportSqlManage } diff --git a/sqle/api/controller/v2/sql_manage.go b/sqle/api/controller/v2/sql_manage.go index 06d44a9a33..726024f146 100644 --- a/sqle/api/controller/v2/sql_manage.go +++ b/sqle/api/controller/v2/sql_manage.go @@ -50,6 +50,7 @@ type SqlManage struct { // @Param filter_last_audit_start_time_to query string false "last audit start time to" // @Param filter_status query string false "status" Enums(unhandled,solved,ignored,manual_audited,sent) // @Param filter_rule_name query string false "rule name" +// @Param filter_remediation_status query string false "remediation status" Enums(resolved,partially_fixed,unchanged,deteriorated,newly_discovered) // @Param filter_db_type query string false "db type" // @Param filter_business query string false "filter by business" // This parameter is deprecated // @Param filter_by_environment_tag query string false "filter by environment tag" diff --git a/sqle/cmd/sqled/backfill_sql_manage_first_audit.go b/sqle/cmd/sqled/backfill_sql_manage_first_audit.go new file mode 100644 index 0000000000..6524b0c7ef --- /dev/null +++ b/sqle/cmd/sqled/backfill_sql_manage_first_audit.go @@ -0,0 +1,62 @@ +package main + +import ( + "fmt" + + dmsCommonAes "github.com/actiontech/dms/pkg/dms-common/pkg/aes" + "github.com/actiontech/sqle/sqle/config" + "github.com/actiontech/sqle/sqle/model" + "github.com/actiontech/sqle/sqle/utils" + "github.com/spf13/cobra" +) + +func backfillSQLManageFirstAuditCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "backfill-sql-manage-first-audit", + Short: "Backfill SQL manage first audit result from latest audit result", + RunE: func(cmd *cobra.Command, args []string) error { + storage, err := newStorageFromCommandConfig() + if err != nil { + return err + } + model.InitStorage(storage) + + affectedRows, err := storage.BackfillSQLManageFirstAuditResult() + if err != nil { + return err + } + fmt.Printf("backfilled sql_manage_records first audit result rows: %d\n", affectedRows) + return nil + }, + } + cmd.Flags().StringVarP(&configPath, "config", "", "", "config file path") + cmd.Flags().StringVarP(&mysqlUser, "mysql-user", "", "sqle", "mysql user") + cmd.Flags().StringVarP(&mysqlPass, "mysql-password", "", "sqle", "mysql password") + cmd.Flags().StringVarP(&mysqlHost, "mysql-host", "", "localhost", "mysql host") + cmd.Flags().StringVarP(&mysqlPort, "mysql-port", "", "3306", "mysql port") + cmd.Flags().StringVarP(&mysqlSchema, "mysql-schema", "", "sqle", "mysql schema") + cmd.Flags().BoolVarP(&debug, "debug", "", false, "debug mode, print more log") + return cmd +} + +func newStorageFromCommandConfig() (*model.Storage, error) { + if configPath != "" { + config.ParseConfigFile(configPath) + dbConfig := config.GetOptions().SqleOptions.Service.Database + dbPassword := dbConfig.Password + if dbConfig.SecretPassword != "" { + password, err := dmsCommonAes.AesDecrypt(dbConfig.SecretPassword) + if err != nil { + return nil, fmt.Errorf("read db info from config file error, %d", err) + } + dbPassword = password + } + return model.NewStorage(dbConfig.User, dbPassword, dbConfig.Host, dbConfig.Port, dbConfig.Schema, debug) + } + + plainPassword, err := utils.DecodeString(mysqlPass) + if err != nil { + return nil, fmt.Errorf("decode mysql password to string error : %v", err) + } + return model.NewStorage(mysqlUser, plainPassword, mysqlHost, mysqlPort, mysqlSchema, debug) +} diff --git a/sqle/cmd/sqled/sqled.go b/sqle/cmd/sqled/sqled.go index 9e294e5851..8602b96d22 100644 --- a/sqle/cmd/sqled/sqled.go +++ b/sqle/cmd/sqled/sqled.go @@ -68,6 +68,7 @@ func main() { rootCmd.Flags().StringVarP(&pluginPath, "plugin-path", "", "", "plugin path") rootCmd.AddCommand(genSecretPasswordCmd()) + rootCmd.AddCommand(backfillSQLManageFirstAuditCmd()) if err := rootCmd.Execute(); err != nil { log.NewEntry().Error("sqle abnormal termination:", err) os.Exit(1) diff --git a/sqle/model/instance_audit_plan.go b/sqle/model/instance_audit_plan.go index 158bf99c55..a0bbc3db13 100644 --- a/sqle/model/instance_audit_plan.go +++ b/sqle/model/instance_audit_plan.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "math" "time" "github.com/actiontech/sqle/sqle/errors" @@ -36,6 +37,23 @@ const ( LastCollectionAbnormal = "abnormal" ) +const ( + RuleDiffResolved = "resolved" + RuleDiffNew = "new" + RuleDiffUnchanged = "unchanged" + + auditLevelNormal = "normal" + auditLevelNotice = "notice" + auditLevelWarn = "warn" + auditLevelError = "error" + + RemediationStatusResolved = "resolved" + RemediationStatusPartiallyFixed = "partially_fixed" + RemediationStatusUnchanged = "unchanged" + RemediationStatusDeteriorated = "deteriorated" + RemediationStatusNewlyDiscovered = "newly_discovered" +) + // TODO 推送配置 type NotifyConfig struct { // NotifyInterval int `json:"notify_interval" gorm:"default:10"` @@ -210,22 +228,228 @@ func (s *Storage) HasSQLManageRecords(sourceId string, source string) (bool, err type SQLManageRecord struct { Model - Source string `json:"source" gorm:"type:varchar(255);index:idx_source_id_source"` - SourceId string `json:"source_id" gorm:"type:varchar(255);index:idx_source_id_source"` - ProjectId string `json:"project_id" gorm:"type:varchar(255)"` - InstanceID string `json:"instance_id" gorm:"type:varchar(255)"` - SchemaName string `json:"schema_name" gorm:"type:varchar(255)"` - SqlFingerprint string `json:"sql_fingerprint" gorm:"type:mediumtext;not null"` - SqlText string `json:"sql_text" gorm:"type:mediumtext;not null"` - Info JSON `gorm:"type:json"` // 慢日志的 执行时间等特殊属性 - AuditLevel string `json:"audit_level" gorm:"type:varchar(255)"` - AuditResults *AuditResults `json:"audit_results" gorm:"type:json"` - SQLID string `json:"sql_id" gorm:"type:varchar(255);unique;not null"` - Priority sql.NullString `json:"priority" gorm:"type:varchar(255)"` + Source string `json:"source" gorm:"type:varchar(255);index:idx_source_id_source"` + SourceId string `json:"source_id" gorm:"type:varchar(255);index:idx_source_id_source"` + ProjectId string `json:"project_id" gorm:"type:varchar(255)"` + InstanceID string `json:"instance_id" gorm:"type:varchar(255)"` + SchemaName string `json:"schema_name" gorm:"type:varchar(255)"` + SqlFingerprint string `json:"sql_fingerprint" gorm:"type:mediumtext;not null"` + SqlText string `json:"sql_text" gorm:"type:mediumtext;not null"` + Info JSON `gorm:"type:json"` // 慢日志的 执行时间等特殊属性 + AuditLevel string `json:"audit_level" gorm:"type:varchar(255)"` + AuditResults *AuditResults `json:"audit_results" gorm:"type:json"` + FirstAuditResults AuditResults `json:"first_audit_results" gorm:"type:json"` + FirstAuditTime *time.Time `json:"first_audit_time" gorm:"type:datetime(3)"` + SQLID string `json:"sql_id" gorm:"type:varchar(255);unique;not null"` + Priority sql.NullString `json:"priority" gorm:"type:varchar(255)"` SQLManager SQLManageRecordProcess } +type RuleDiff struct { + Type string `json:"type"` + AuditRules AuditResults `json:"audit_rules"` +} + +type RemediationResult struct { + Resolved AuditResults `json:"resolved"` + New AuditResults `json:"new"` + Unchanged AuditResults `json:"unchanged"` + Status string `json:"status"` + FirstAuditMissing bool `json:"first_audit_missing"` +} + +type RemediationStatusCounter struct { + Resolved uint64 `json:"resolved"` + PartiallyFixed uint64 `json:"partially_fixed"` + Unchanged uint64 `json:"unchanged"` + Deteriorated uint64 `json:"deteriorated"` + NewlyDiscovered uint64 `json:"newly_discovered"` +} + +type SqlManageRemediationOverview struct { + ProjectID string `json:"project_id"` + InstanceAuditPlanID string `json:"instance_audit_plan_id"` + AuditPlanType string `json:"audit_plan_type"` + SqlTotalNum uint64 `json:"sql_total_num"` + FirstScore int32 `json:"first_score"` + LatestScore int32 `json:"latest_score"` + ScoreChange int32 `json:"score_change"` + RemediationRate float64 `json:"remediation_rate"` + RemediationStatusCount RemediationStatusCounter `json:"remediation_status_count"` + FirstAuditMissingNum uint64 `json:"first_audit_missing_num"` +} + +func CalculateRemediationResult(firstAuditResults, latestAuditResults AuditResults, firstAuditMissing bool) RemediationResult { + firstByRuleName := auditResultsByRuleName(firstAuditResults) + latestByRuleName := auditResultsByRuleName(latestAuditResults) + + result := RemediationResult{FirstAuditMissing: firstAuditMissing} + for ruleName, firstAuditResult := range firstByRuleName { + if _, ok := latestByRuleName[ruleName]; ok { + result.Unchanged = append(result.Unchanged, firstAuditResult) + continue + } + result.Resolved = append(result.Resolved, firstAuditResult) + } + for ruleName, latestAuditResult := range latestByRuleName { + if _, ok := firstByRuleName[ruleName]; ok { + continue + } + result.New = append(result.New, latestAuditResult) + } + + switch { + case len(firstByRuleName) == 0 && len(latestByRuleName) > 0: + result.Status = RemediationStatusNewlyDiscovered + case len(firstByRuleName) > 0 && len(latestByRuleName) == 0: + result.Status = RemediationStatusResolved + case len(result.New) > 0: + result.Status = RemediationStatusDeteriorated + case len(result.Resolved) > 0 && len(latestByRuleName) > 0: + result.Status = RemediationStatusPartiallyFixed + default: + result.Status = RemediationStatusUnchanged + } + + return result +} + +func dereferenceAuditResults(auditResults *AuditResults) AuditResults { + if auditResults == nil { + return nil + } + return *auditResults +} + +func auditResultsByRuleName(auditResults AuditResults) map[string]AuditResult { + result := make(map[string]AuditResult, len(auditResults)) + for _, auditResult := range auditResults { + if auditResult.RuleName == "" { + continue + } + if _, ok := result[auditResult.RuleName]; !ok { + result[auditResult.RuleName] = auditResult + } + } + return result +} + +func (o SQLManageRecord) RemediationResult() RemediationResult { + return CalculateRemediationResult(o.FirstAuditResults, dereferenceAuditResults(o.AuditResults), o.FirstAuditTime == nil) +} + +func CalculateSqlManageRemediationOverview(projectID, instanceAuditPlanID, auditPlanType string, records []*SQLManageRecord) SqlManageRemediationOverview { + overview := SqlManageRemediationOverview{ + ProjectID: projectID, + InstanceAuditPlanID: instanceAuditPlanID, + AuditPlanType: auditPlanType, + SqlTotalNum: uint64(len(records)), + } + if len(records) == 0 { + return overview + } + + firstAuditResults := make([]AuditResults, 0, len(records)) + latestAuditResults := make([]AuditResults, 0, len(records)) + for _, record := range records { + if record == nil { + continue + } + remediationResult := record.RemediationResult() + switch remediationResult.Status { + case RemediationStatusResolved: + overview.RemediationStatusCount.Resolved++ + case RemediationStatusPartiallyFixed: + overview.RemediationStatusCount.PartiallyFixed++ + case RemediationStatusDeteriorated: + overview.RemediationStatusCount.Deteriorated++ + case RemediationStatusNewlyDiscovered: + overview.RemediationStatusCount.NewlyDiscovered++ + default: + overview.RemediationStatusCount.Unchanged++ + } + if remediationResult.FirstAuditMissing { + overview.FirstAuditMissingNum++ + } + firstAuditResults = append(firstAuditResults, record.FirstAuditResults) + latestAuditResults = append(latestAuditResults, dereferenceAuditResults(record.AuditResults)) + } + + overview.FirstScore = CalculateAuditResultsScore(firstAuditResults) + overview.LatestScore = CalculateAuditResultsScore(latestAuditResults) + overview.ScoreChange = overview.LatestScore - overview.FirstScore + remediated := overview.RemediationStatusCount.Resolved + overview.RemediationStatusCount.PartiallyFixed + overview.RemediationRate = float64(remediated) / float64(len(records)) + return overview +} + +func CalculateAuditResultsScore(auditResultsList []AuditResults) int32 { + if len(auditResultsList) == 0 { + return 0 + } + + var errorCount, warnCount, noticeCount float64 + for _, auditResults := range auditResultsList { + switch auditLevelFromResults(auditResults) { + case auditLevelError: + errorCount++ + case auditLevelWarn: + warnCount++ + case auditLevelNotice: + noticeCount++ + } + } + + numberOfTask := float64(len(auditResultsList)) + errorRate := errorCount / numberOfTask + warnRate := (warnCount + errorCount) / numberOfTask + noticeRate := (noticeCount + warnCount + errorCount) / numberOfTask + passRate := (numberOfTask - noticeCount - warnCount - errorCount) / numberOfTask + + totalScore := passRate * 30 + totalScore += (1 - errorRate) * 15 + totalScore += (1 - warnRate) * 10 + totalScore += (1 - noticeRate) * 5 + if errorRate == 0 { + totalScore += 15 + } + if warnRate == 0 { + totalScore += 10 + } + if noticeRate == 0 { + totalScore += 5 + } + if errorRate < 0.1 { + totalScore += 5 + } + if warnRate < 0.1 { + totalScore += 3 + } + if noticeRate < 0.1 { + totalScore += 2 + } + + return int32(math.Floor(totalScore)) +} + +func auditLevelFromResults(auditResults AuditResults) string { + level := auditLevelNormal + for _, auditResult := range auditResults { + switch auditResult.Level { + case auditLevelError: + return auditLevelError + case auditLevelWarn: + level = auditLevelWarn + case auditLevelNotice: + if level == auditLevelNormal { + level = auditLevelNotice + } + } + } + return level +} + func (o SQLManageRecord) GetFingerprintMD5() string { if o.SQLID != "" { return o.SQLID @@ -379,6 +603,29 @@ func (s *Storage) GetManagerSQLListByAuditPlanId(auditPlanID uint) ([]*SQLManage return sqls, nil } +func (s *Storage) GetManagerSQLListByInstanceAuditPlanAndType(instanceAuditPlanID, auditPlanType string) ([]*SQLManageRecord, error) { + sqls := []*SQLManageRecord{} + err := s.db.Where("source_id = ? AND source = ?", instanceAuditPlanID, auditPlanType).Find(&sqls).Error + if err != nil { + return nil, err + } + return sqls, nil +} + +func (s *Storage) BackfillSQLManageFirstAuditResult() (int64, error) { + result := s.db.Model(&SQLManageRecord{}). + Where("(first_audit_results IS NULL OR first_audit_results = '' OR first_audit_results = 'null')"). + Where("audit_results IS NOT NULL AND audit_results <> '' AND audit_results <> 'null'"). + Updates(map[string]interface{}{ + "first_audit_results": gorm.Expr("audit_results"), + "first_audit_time": gorm.Expr("last_receive_timestamp"), + }) + if result.Error != nil { + return 0, result.Error + } + return result.RowsAffected, nil +} + // 获取指定扫描任务下的所有Schema func (s *Storage) GetManagerSqlSchemaNameByAuditPlan(auditPlanId uint) ([]string, error) { var metricValueTips []string diff --git a/sqle/model/sql_manage_remediation_test.go b/sqle/model/sql_manage_remediation_test.go new file mode 100644 index 0000000000..2a61959456 --- /dev/null +++ b/sqle/model/sql_manage_remediation_test.go @@ -0,0 +1,115 @@ +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCalculateRemediationResultRuleDiff(t *testing.T) { + result := CalculateRemediationResult( + AuditResults{ + {RuleName: "rule_resolved", Level: "warn"}, + {RuleName: "rule_unchanged", Level: "warn"}, + }, + AuditResults{ + {RuleName: "rule_unchanged", Level: "error"}, + {RuleName: "rule_new", Level: "warn"}, + }, + false, + ) + + assert.Equal(t, RemediationStatusDeteriorated, result.Status) + assert.Equal(t, []string{"rule_resolved"}, auditResultRuleNames(result.Resolved)) + assert.Equal(t, []string{"rule_new"}, auditResultRuleNames(result.New)) + assert.Equal(t, []string{"rule_unchanged"}, auditResultRuleNames(result.Unchanged)) +} + +func TestCalculateRemediationResultStatus(t *testing.T) { + testCases := []struct { + name string + first AuditResults + latest AuditResults + status string + }{ + {name: "resolved", first: auditResults("rule_a"), latest: nil, status: RemediationStatusResolved}, + {name: "partially fixed", first: auditResults("rule_a", "rule_b"), latest: auditResults("rule_b"), status: RemediationStatusPartiallyFixed}, + {name: "unchanged", first: auditResults("rule_a"), latest: auditResults("rule_a"), status: RemediationStatusUnchanged}, + {name: "deteriorated", first: auditResults("rule_a"), latest: auditResults("rule_a", "rule_b"), status: RemediationStatusDeteriorated}, + {name: "newly discovered", first: nil, latest: auditResults("rule_a"), status: RemediationStatusNewlyDiscovered}, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + result := CalculateRemediationResult(testCase.first, testCase.latest, false) + assert.Equal(t, testCase.status, result.Status) + }) + } +} + +func TestCalculateRemediationResultFirstAuditMissing(t *testing.T) { + result := CalculateRemediationResult(nil, auditResults("rule_a"), true) + + assert.True(t, result.FirstAuditMissing) + assert.Equal(t, RemediationStatusNewlyDiscovered, result.Status) + assert.Equal(t, []string{"rule_a"}, auditResultRuleNames(result.New)) +} + +func TestAuditResultsScanNull(t *testing.T) { + var results AuditResults + + assert.NoError(t, results.Scan(nil)) + assert.Nil(t, results) + assert.NoError(t, results.Scan([]byte("null"))) + assert.Nil(t, results) +} + +func TestCalculateAuditResultsScore(t *testing.T) { + assert.Equal(t, int32(100), CalculateAuditResultsScore([]AuditResults{nil, {}})) + assert.Equal(t, int32(0), CalculateAuditResultsScore([]AuditResults{auditResultsWithLevel(auditLevelError)})) + assert.Greater(t, CalculateAuditResultsScore([]AuditResults{auditResultsWithLevel(auditLevelWarn)}), CalculateAuditResultsScore([]AuditResults{auditResultsWithLevel(auditLevelError)})) +} + +func TestCalculateSqlManageRemediationOverview(t *testing.T) { + records := []*SQLManageRecord{ + {FirstAuditResults: auditResults("rule_a"), AuditResults: nil}, + {FirstAuditResults: auditResults("rule_b", "rule_c"), AuditResults: auditResultsPtr("rule_c")}, + {FirstAuditResults: auditResults("rule_d"), AuditResults: auditResultsPtr("rule_d", "rule_e")}, + } + + overview := CalculateSqlManageRemediationOverview("project-id", "plan-id", "default", records) + + assert.Equal(t, "project-id", overview.ProjectID) + assert.Equal(t, "plan-id", overview.InstanceAuditPlanID) + assert.Equal(t, "default", overview.AuditPlanType) + assert.Equal(t, uint64(3), overview.SqlTotalNum) + assert.Equal(t, uint64(1), overview.RemediationStatusCount.Resolved) + assert.Equal(t, uint64(1), overview.RemediationStatusCount.PartiallyFixed) + assert.Equal(t, uint64(1), overview.RemediationStatusCount.Deteriorated) + assert.InDelta(t, 2.0/3.0, overview.RemediationRate, 0.001) +} + +func auditResultsPtr(ruleNames ...string) *AuditResults { + results := auditResults(ruleNames...) + return &results +} + +func auditResults(ruleNames ...string) AuditResults { + results := make(AuditResults, 0, len(ruleNames)) + for _, ruleName := range ruleNames { + results = append(results, AuditResult{RuleName: ruleName, Level: auditLevelWarn}) + } + return results +} + +func auditResultsWithLevel(level string) AuditResults { + return AuditResults{{RuleName: level + "_rule", Level: level}} +} + +func auditResultRuleNames(auditResults AuditResults) []string { + ruleNames := make([]string, 0, len(auditResults)) + for _, auditResult := range auditResults { + ruleNames = append(ruleNames, auditResult.RuleName) + } + return ruleNames +}