diff --git a/internal/diff/column.go b/internal/diff/column.go index fdacac45..bc197e98 100644 --- a/internal/diff/column.go +++ b/internal/diff/column.go @@ -2,6 +2,7 @@ package diff import ( "fmt" + "strings" "github.com/pgschema/pgschema/ir" ) @@ -74,22 +75,32 @@ func (cd *ColumnDiff) generateColumnSQL(tableSchema, tableName string, targetSch } } else { // Normal default value change handling (no USING clause involved) - if (oldDefault == nil) != (newDefault == nil) || - (oldDefault != nil && newDefault != nil && *oldDefault != *newDefault) { - - var sql string - if newDefault == nil { - sql = fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;", - qualifiedTableName, ir.QuoteIdentifier(cd.New.Name)) - } else { - sql = fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", - qualifiedTableName, ir.QuoteIdentifier(cd.New.Name), *newDefault) - } - + // We only drop default values when they are not sequences + // Sequences are automatically handled by the DROP CASCADE statement + if oldDefault != nil && newDefault == nil && !strings.HasPrefix(*oldDefault, "nextval(") { + sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;", + qualifiedTableName, ir.QuoteIdentifier(cd.New.Name)) + statements = append(statements, sql) + } + if (oldDefault == nil && newDefault != nil) || (oldDefault != nil && newDefault != nil && *oldDefault != *newDefault) { + sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", + qualifiedTableName, ir.QuoteIdentifier(cd.New.Name), *newDefault) statements = append(statements, sql) } } + // Handle identity column changes + if cd.Old.Identity != nil && (cd.New.Identity == nil || cd.Old.Identity.Generation != cd.New.Identity.Generation) { + sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP IDENTITY;", + qualifiedTableName, ir.QuoteIdentifier(cd.New.Name)) + statements = append(statements, sql) + } + if cd.New.Identity != nil && (cd.Old.Identity == nil || cd.Old.Identity.Generation != cd.New.Identity.Generation) { + sql := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s ADD GENERATED %s AS IDENTITY;", + qualifiedTableName, ir.QuoteIdentifier(cd.New.Name), cd.New.Identity.Generation) + statements = append(statements, sql) + } + return statements } diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 01fd298e..06006dcb 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -276,12 +276,12 @@ type ddlDiff struct { droppedDefaultPrivileges []*ir.DefaultPrivilege modifiedDefaultPrivileges []*defaultPrivilegeDiff // Explicit object privileges - addedPrivileges []*ir.Privilege - droppedPrivileges []*ir.Privilege - modifiedPrivileges []*privilegeDiff - revokedDefaultGrantsOnNewTables []*ir.Privilege // Privileges to revoke on newly created tables (issue #253) - addedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege - droppedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege + addedPrivileges []*ir.Privilege + droppedPrivileges []*ir.Privilege + modifiedPrivileges []*privilegeDiff + revokedDefaultGrantsOnNewTables []*ir.Privilege // Privileges to revoke on newly created tables (issue #253) + addedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege + droppedRevokedDefaultPrivs []*ir.RevokedDefaultPrivilege // Column-level privileges addedColumnPrivileges []*ir.ColumnPrivilege droppedColumnPrivileges []*ir.ColumnPrivilege @@ -915,8 +915,10 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { for _, key := range seqKeys { seq := newSequences[key] if _, exists := oldSequences[key]; !exists { - // Skip sequences owned by table columns (created by SERIAL) - if seq.OwnedByTable != "" && seq.OwnedByColumn != "" { + // Skip sequences owned by table columns only if the column is also new + // (created by SERIAL in CREATE TABLE). If the column already exists, + // we need to create the sequence explicitly for ALTER COLUMN to use. + if seq.OwnedByTable != "" && seq.OwnedByColumn != "" && !columnExistsInTables(oldTables, seq.Schema, seq.OwnedByTable, seq.OwnedByColumn) { continue } diff.addedSequences = append(diff.addedSequences, seq) @@ -929,7 +931,7 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { seq := oldSequences[key] if _, exists := newSequences[key]; !exists { // Skip sequences owned by table columns (created by SERIAL) - if seq.OwnedByTable != "" && seq.OwnedByColumn != "" { + if seq.OwnedByTable != "" && seq.OwnedByColumn != "" && !columnExistsInTables(newTables, seq.Schema, seq.OwnedByTable, seq.OwnedByColumn) { continue } diff.droppedSequences = append(diff.droppedSequences, seq) @@ -1797,6 +1799,19 @@ func sortedKeys[T any](m map[string]T) []string { return keys } +// columnExistsInTables checks if a column exists in the given tables map +func columnExistsInTables(tables map[string]*ir.Table, schema, tableName, columnName string) bool { + tableKey := schema + "." + tableName + if table, exists := tables[tableKey]; exists { + for _, col := range table.Columns { + if col.Name == columnName { + return true + } + } + } + return false +} + // buildFunctionLookup returns case-insensitive lookup keys for newly added functions. // Keys include both unqualified (function name only) and schema-qualified identifiers. func buildFunctionLookup(functions []*ir.Function) map[string]struct{} { diff --git a/internal/diff/sequence.go b/internal/diff/sequence.go index 2526c56b..588d1c06 100644 --- a/internal/diff/sequence.go +++ b/internal/diff/sequence.go @@ -11,9 +11,9 @@ import ( // Default values for PostgreSQL sequences by data type const ( defaultSequenceMinValue int64 = 1 - defaultSequenceMaxValue int64 = math.MaxInt64 // bigint max - smallintMaxValue int64 = math.MaxInt16 // smallint max - integerMaxValue int64 = math.MaxInt32 // integer max + defaultSequenceMaxValue int64 = math.MaxInt64 // bigint max + smallintMaxValue int64 = math.MaxInt16 // smallint max + integerMaxValue int64 = math.MaxInt32 // integer max ) // generateCreateSequencesSQL generates CREATE SEQUENCE statements @@ -113,6 +113,11 @@ func generateSequenceSQL(seq *ir.Sequence, targetSchema string) string { parts = append(parts, "CYCLE") } + // Add sequence owner + if seq.OwnedByTable != "" && seq.OwnedByColumn != "" { + parts = append(parts, fmt.Sprintf("OWNED BY %s.%s", seq.OwnedByTable, seq.OwnedByColumn)) + } + // Join with proper formatting if len(parts) > 1 { return parts[0] + " " + strings.Join(parts[1:], " ") + ";" @@ -201,7 +206,7 @@ func sequencesEqual(old, new *ir.Sequence) bool { if old.Name != new.Name { return false } - + // Compare DataType (default is bigint if empty) oldDataType := old.DataType if oldDataType == "" { @@ -214,7 +219,7 @@ func sequencesEqual(old, new *ir.Sequence) bool { if oldDataType != newDataType { return false } - + if old.StartValue != new.StartValue { return false } diff --git a/internal/plan/rewrite.go b/internal/plan/rewrite.go index 8498261e..1239e11c 100644 --- a/internal/plan/rewrite.go +++ b/internal/plan/rewrite.go @@ -89,6 +89,16 @@ func generateRewrite(d diff.Diff, newlyCreatedTables map[string]bool, newlyCreat } } } + // Check if identity is being added or changed on an existing column + // This includes: adding identity, or changing identity generation (drop + re-add) + if columnDiff.New.Identity != nil { + // Verify this diff's SQL actually contains ADD GENERATED + for _, stmt := range d.Statements { + if strings.Contains(stmt.SQL, "ADD GENERATED") { + return generateColumnIdentityRewrite(columnDiff, d.Path) + } + } + } } } } @@ -324,6 +334,40 @@ func generateColumnNotNullRewrite(_ *diff.ColumnDiff, path string) []RewriteStep } } +// generateColumnIdentityRewrite generates rewrite steps for ADD GENERATED AS IDENTITY operations +// It syncs the identity sequence with existing data to prevent conflicts +func generateColumnIdentityRewrite(columnDiff *diff.ColumnDiff, path string) []RewriteStep { + // Parse path (schema.table.column) to extract schema, table, and column names + parts := strings.Split(path, ".") + if len(parts) != 3 { + return nil + } + schema := parts[0] + table := parts[1] + column := parts[2] + + tableName := getTableNameWithSchema(schema, table) + + // Step 1: Add identity column + addIdentitySQL := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s ADD GENERATED %s AS IDENTITY;", + tableName, ir.QuoteIdentifier(column), columnDiff.New.Identity.Generation) + + // Step 2: Sync sequence with existing data + setvalSQL := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), COALESCE(MAX(%s), 0) + 1) FROM %s;", + tableName, column, ir.QuoteIdentifier(column), tableName) + + return []RewriteStep{ + { + SQL: addIdentitySQL, + CanRunInTransaction: true, + }, + { + SQL: setvalSQL, + CanRunInTransaction: true, + }, + } +} + // generateIndexSQL generates CREATE INDEX statement func generateIndexSQL(index *ir.Index, isConcurrent bool) string { var sql strings.Builder diff --git a/testdata/diff/create_table/alter_identity/diff.sql b/testdata/diff/create_table/alter_identity/diff.sql new file mode 100644 index 00000000..3734e7d2 --- /dev/null +++ b/testdata/diff/create_table/alter_identity/diff.sql @@ -0,0 +1,6 @@ +DROP SEQUENCE IF EXISTS table1_c2_seq CASCADE; +CREATE SEQUENCE IF NOT EXISTS table1_c1_seq AS integer OWNED BY table1.c1; +ALTER TABLE table1 ALTER COLUMN c1 SET DEFAULT nextval('table1_c1_seq'::regclass); +ALTER TABLE table1 ALTER COLUMN c2 ADD GENERATED ALWAYS AS IDENTITY; +ALTER TABLE table1 ALTER COLUMN c3 DROP IDENTITY; +ALTER TABLE table1 ALTER COLUMN c3 ADD GENERATED BY DEFAULT AS IDENTITY; diff --git a/testdata/diff/create_table/alter_identity/new.sql b/testdata/diff/create_table/alter_identity/new.sql new file mode 100644 index 00000000..a5d2330d --- /dev/null +++ b/testdata/diff/create_table/alter_identity/new.sql @@ -0,0 +1,5 @@ +CREATE TABLE public.table1 ( + c1 serial NOT NULL, + c2 int GENERATED ALWAYS AS IDENTITY, + c3 int GENERATED BY DEFAULT AS IDENTITY +); diff --git a/testdata/diff/create_table/alter_identity/old.sql b/testdata/diff/create_table/alter_identity/old.sql new file mode 100644 index 00000000..a38eb3c2 --- /dev/null +++ b/testdata/diff/create_table/alter_identity/old.sql @@ -0,0 +1,5 @@ +CREATE TABLE public.table1 ( + c1 int NOT NULL, + c2 serial, + c3 int GENERATED ALWAYS AS IDENTITY +); diff --git a/testdata/diff/create_table/alter_identity/plan.json b/testdata/diff/create_table/alter_identity/plan.json new file mode 100644 index 00000000..49f71839 --- /dev/null +++ b/testdata/diff/create_table/alter_identity/plan.json @@ -0,0 +1,62 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.6.2", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": { + "hash": "5b91475214f7a1b4e4928c9480533b61f841d70494784aff431f1f392fba1e58" + }, + "groups": [ + { + "steps": [ + { + "sql": "DROP SEQUENCE IF EXISTS table1_c2_seq CASCADE;", + "type": "sequence", + "operation": "drop", + "path": "public.table1_c2_seq" + }, + { + "sql": "CREATE SEQUENCE IF NOT EXISTS table1_c1_seq AS integer OWNED BY table1.c1;", + "type": "sequence", + "operation": "create", + "path": "public.table1_c1_seq" + }, + { + "sql": "ALTER TABLE table1 ALTER COLUMN c1 SET DEFAULT nextval('table1_c1_seq'::regclass);", + "type": "table.column", + "operation": "alter", + "path": "public.table1.c1" + }, + { + "sql": "ALTER TABLE table1 ALTER COLUMN c2 ADD GENERATED ALWAYS AS IDENTITY;", + "type": "table.column", + "operation": "alter", + "path": "public.table1.c2" + }, + { + "sql": "SELECT setval(pg_get_serial_sequence('table1', 'c2'), COALESCE(MAX(c2), 0) + 1) FROM table1;", + "type": "table.column", + "operation": "alter", + "path": "public.table1.c2" + }, + { + "sql": "ALTER TABLE table1 ALTER COLUMN c3 DROP IDENTITY;", + "type": "table.column", + "operation": "alter", + "path": "public.table1.c3" + }, + { + "sql": "ALTER TABLE table1 ALTER COLUMN c3 ADD GENERATED BY DEFAULT AS IDENTITY;", + "type": "table.column", + "operation": "alter", + "path": "public.table1.c3" + }, + { + "sql": "SELECT setval(pg_get_serial_sequence('table1', 'c3'), COALESCE(MAX(c3), 0) + 1) FROM table1;", + "type": "table.column", + "operation": "alter", + "path": "public.table1.c3" + } + ] + } + ] +} diff --git a/testdata/diff/create_table/alter_identity/plan.sql b/testdata/diff/create_table/alter_identity/plan.sql new file mode 100644 index 00000000..d5031a70 --- /dev/null +++ b/testdata/diff/create_table/alter_identity/plan.sql @@ -0,0 +1,15 @@ +DROP SEQUENCE IF EXISTS table1_c2_seq CASCADE; + +CREATE SEQUENCE IF NOT EXISTS table1_c1_seq AS integer OWNED BY table1.c1; + +ALTER TABLE table1 ALTER COLUMN c1 SET DEFAULT nextval('table1_c1_seq'::regclass); + +ALTER TABLE table1 ALTER COLUMN c2 ADD GENERATED ALWAYS AS IDENTITY; + +SELECT setval(pg_get_serial_sequence('table1', 'c2'), COALESCE(MAX(c2), 0) + 1) FROM table1; + +ALTER TABLE table1 ALTER COLUMN c3 DROP IDENTITY; + +ALTER TABLE table1 ALTER COLUMN c3 ADD GENERATED BY DEFAULT AS IDENTITY; + +SELECT setval(pg_get_serial_sequence('table1', 'c3'), COALESCE(MAX(c3), 0) + 1) FROM table1; diff --git a/testdata/diff/create_table/alter_identity/plan.txt b/testdata/diff/create_table/alter_identity/plan.txt new file mode 100644 index 00000000..531a2bd1 --- /dev/null +++ b/testdata/diff/create_table/alter_identity/plan.txt @@ -0,0 +1,34 @@ +Plan: 1 to add, 1 to modify, 1 to drop. + +Summary by type: + sequences: 1 to add, 1 to drop + tables: 1 to modify + +Sequences: + + table1_c1_seq + - table1_c2_seq + +Tables: + ~ table1 + ~ c1 (column) + ~ c2 (column) + ~ c3 (column) + +DDL to be executed: +-------------------------------------------------- + +DROP SEQUENCE IF EXISTS table1_c2_seq CASCADE; + +CREATE SEQUENCE IF NOT EXISTS table1_c1_seq AS integer OWNED BY table1.c1; + +ALTER TABLE table1 ALTER COLUMN c1 SET DEFAULT nextval('table1_c1_seq'::regclass); + +ALTER TABLE table1 ALTER COLUMN c2 ADD GENERATED ALWAYS AS IDENTITY; + +SELECT setval(pg_get_serial_sequence('table1', 'c2'), COALESCE(MAX(c2), 0) + 1) FROM table1; + +ALTER TABLE table1 ALTER COLUMN c3 DROP IDENTITY; + +ALTER TABLE table1 ALTER COLUMN c3 ADD GENERATED BY DEFAULT AS IDENTITY; + +SELECT setval(pg_get_serial_sequence('table1', 'c3'), COALESCE(MAX(c3), 0) + 1) FROM table1;