diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..7b86bda --- /dev/null +++ b/builder.go @@ -0,0 +1,189 @@ +package taskgraph + +import "fmt" + +// TaskBuilder helps construct taskgraph Tasks with a fluent API. +type TaskBuilder[T any] struct { + name string + resultKey Key[T] + depends []any + fn any + condition Condition + defaultVal T + defaultSet bool + defaultBindings []Binding +} + +// NewTaskBuilder creates a new builder for a task that produces a result of type T. +func NewTaskBuilder[T any](name string, key Key[T]) *TaskBuilder[T] { + return &TaskBuilder[T]{ + name: name, + resultKey: key, + } +} + +// DependsOn adds dependencies to the task. +func (b *TaskBuilder[T]) DependsOn(deps ...any) *TaskBuilder[T] { + b.depends = append(b.depends, deps...) + return b +} + +// Run sets the function to execute. The function signature must match the dependencies. +func (b *TaskBuilder[T]) Run(fn any) *TaskBuilder[T] { + b.fn = fn + return b +} + +// RunIf sets a condition for the task execution. +func (b *TaskBuilder[T]) RunIf(cond Condition) *TaskBuilder[T] { + b.condition = cond + return b +} + +// RunIfAll sets a ConditionAnd (logical AND) for the task execution using the provided keys. +func (b *TaskBuilder[T]) RunIfAll(keys ...ReadOnlyKey[bool]) *TaskBuilder[T] { + b.condition = ConditionAnd(keys) + return b +} + +// RunIfAny sets a ConditionOr (logical OR) for the task execution using the provided keys. +func (b *TaskBuilder[T]) RunIfAny(keys ...ReadOnlyKey[bool]) *TaskBuilder[T] { + b.condition = ConditionOr(keys) + return b +} + +// Default sets the default value for the result key if the condition is false. +func (b *TaskBuilder[T]) Default(val T) *TaskBuilder[T] { + b.defaultVal = val + b.defaultSet = true + return b +} + +// WithDefaultBindings adds arbitrary default bindings if the condition is false. +func (b *TaskBuilder[T]) WithDefaultBindings(bindings ...Binding) *TaskBuilder[T] { + b.defaultBindings = append(b.defaultBindings, bindings...) + return b +} + +// Build constructs and returns the Task. +func (b *TaskBuilder[T]) Build() TaskSet { + reflect := Reflect[T]{ + Name: b.name, + ResultKey: b.resultKey, + Depends: b.depends, + Fn: b.fn, + } + reflect.location = getLocation(2) + var task TaskSet = reflect + + if b.condition != nil { + conditional := Conditional{ + Wrapped: task, + Condition: b.condition, + } + + if b.defaultSet { + conditional.DefaultBindings = append(conditional.DefaultBindings, b.resultKey.Bind(b.defaultVal)) + } + conditional.DefaultBindings = append(conditional.DefaultBindings, b.defaultBindings...) + + conditional.location = getLocation(2) + task = conditional + } + + return task +} + +// MultiTaskBuilder helps construct taskgraph Tasks that provide multiple outputs or perform side effects. +type MultiTaskBuilder struct { + name string + depends []any + fn any + provides []ID + condition Condition + defaultBindings []Binding +} + +// NewMultiTaskBuilder creates a new builder for a multi-output or side-effect task. +func NewMultiTaskBuilder(name string) *MultiTaskBuilder { + return &MultiTaskBuilder{ + name: name, + } +} + +// DependsOn adds dependencies to the task. +func (b *MultiTaskBuilder) DependsOn(deps ...any) *MultiTaskBuilder { + b.depends = append(b.depends, deps...) + return b +} + +// Provides declares the keys that this task provides. +func (b *MultiTaskBuilder) Provides(keys ...any) *MultiTaskBuilder { + for _, k := range keys { + rk, err := newReflectKey(k) + if err != nil { + panic(fmt.Errorf("invalid key passed to Provides: %w", err)) + } + id, err := rk.ID() + if err != nil { + panic(fmt.Errorf("invalid key ID in Provides: %w", err)) + } + b.provides = append(b.provides, id) + } + return b +} + +// Run sets the function to execute. The function signature must match the dependencies. +// Fn must return []Binding or ([]Binding, error). +func (b *MultiTaskBuilder) Run(fn any) *MultiTaskBuilder { + b.fn = fn + return b +} + +// RunIf sets a condition for the task execution. +func (b *MultiTaskBuilder) RunIf(cond Condition) *MultiTaskBuilder { + b.condition = cond + return b +} + +// RunIfAll sets a ConditionAnd (logical AND) for the task execution using the provided keys. +func (b *MultiTaskBuilder) RunIfAll(keys ...ReadOnlyKey[bool]) *MultiTaskBuilder { + b.condition = ConditionAnd(keys) + return b +} + +// RunIfAny sets a ConditionOr (logical OR) for the task execution using the provided keys. +func (b *MultiTaskBuilder) RunIfAny(keys ...ReadOnlyKey[bool]) *MultiTaskBuilder { + b.condition = ConditionOr(keys) + return b +} + +// WithDefaultBindings adds arbitrary default bindings if the condition is false. +func (b *MultiTaskBuilder) WithDefaultBindings(bindings ...Binding) *MultiTaskBuilder { + b.defaultBindings = append(b.defaultBindings, bindings...) + return b +} + +// Build constructs and returns the Task. +func (b *MultiTaskBuilder) Build() TaskSet { + reflect := ReflectMulti{ + Name: b.name, + Depends: b.depends, + Fn: b.fn, + Provides: b.provides, + } + reflect.location = getLocation(2) + var task TaskSet = reflect + + if b.condition != nil { + conditional := Conditional{ + Wrapped: task, + Condition: b.condition, + DefaultBindings: b.defaultBindings, + } + conditional.location = getLocation(2) + task = conditional + } + + return task +} \ No newline at end of file diff --git a/builder_test.go b/builder_test.go new file mode 100644 index 0000000..a88f9bb --- /dev/null +++ b/builder_test.go @@ -0,0 +1,68 @@ +package taskgraph + +import ( + "testing" +) + +func TestTaskBuilder_RunIfAll(t *testing.T) { + k1 := NewKey[bool]("k1") + k2 := NewKey[bool]("k2") + res := NewKey[string]("res") + + task := NewTaskBuilder[string]("test", res). + Run(func() string { return "ok" }). + RunIfAll(k1, k2). + Default("default"). + Build() + + // Simulate execution (simplified verification) + tasks := task.Tasks() + if len(tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasks)) + } + // We can't easily execute it without a full graph, but we can check if it didn't panic and produced a task. +} + +func TestMultiTaskBuilder_Provides(t *testing.T) { + k1 := NewKey[string]("k1") + k2 := NewKey[int]("k2") + + task := NewMultiTaskBuilder("multi"). + Provides(k1, k2). + Run(func() ([]Binding, error) { + return []Binding{k1.Bind("s"), k2.Bind(1)}, nil + }). + Build() + + tasks := task.Tasks() + if len(tasks) != 1 { + t.Fatalf("expected 1 task, got %d", len(tasks)) + } + provided := tasks[0].Provides() + if len(provided) != 2 { + t.Fatalf("expected 2 provided keys, got %d", len(provided)) + } +} + +func TestMultiTaskBuilder_Provides_InvalidKey(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on invalid key") + } + }() + NewMultiTaskBuilder("fail").Provides("not a key") +} + +func TestMultiTaskBuilder_RunIfAny(t *testing.T) { + k1 := NewKey[bool]("k1") + k2 := NewKey[bool]("k2") + + task := NewMultiTaskBuilder("multi_cond"). + RunIfAny(k1, k2). + Run(func() []Binding { return nil }). + Build() + + if task == nil { + t.Fatal("expected task to be built") + } +} diff --git a/key.go b/key.go index 9a8146e..5e9e7d4 100644 --- a/key.go +++ b/key.go @@ -100,7 +100,7 @@ func (k *key[T]) Get(b Binder) (T, error) { func NewKey[T any](id string) Key[T] { return &key[T]{ id: newID("", id), - location: getLocation(), + location: getLocation(2), } } @@ -109,7 +109,7 @@ func NewKey[T any](id string) Key[T] { func NewNamespacedKey[T any](namespace, id string) Key[T] { return &key[T]{ id: newID(namespace, id), - location: getLocation(), + location: getLocation(2), } } @@ -130,7 +130,7 @@ func (k *presenceKey[T]) Get(b Binder) (bool, error) { func Presence[T any](key ReadOnlyKey[T]) ReadOnlyKey[bool] { return &presenceKey[T]{ ReadOnlyKey: key, - location: getLocation(), + location: getLocation(2), } } @@ -159,7 +159,7 @@ func Mapped[In, Out any](key ReadOnlyKey[In], fn func(In) Out) ReadOnlyKey[Out] return &mappedKey[In, Out]{ ReadOnlyKey: key, fn: fn, - location: getLocation(), + location: getLocation(2), } } @@ -188,6 +188,6 @@ func (k *optionalKey[T]) Get(b Binder) (Maybe[T], error) { func Optional[T any](base ReadOnlyKey[T]) ReadOnlyKey[Maybe[T]] { return &optionalKey[T]{ ReadOnlyKey: base, - location: getLocation(), + location: getLocation(2), } } diff --git a/reflect.go b/reflect.go index f565064..84ee36e 100644 --- a/reflect.go +++ b/reflect.go @@ -258,7 +258,7 @@ type Reflect[T any] struct { // Locate annotates the Reflect with its location in the source code, to make error messages // easier to understand. Calling it is recommended but not required if wrapped in a Conditional func (r Reflect[T]) Locate() Reflect[T] { - r.location = getLocation() + r.location = getLocation(2) return r } @@ -337,7 +337,7 @@ type ReflectMulti struct { // Locate annotates the ReflectMulti with its location in the source code, to make error messages // easier to understand. Calling it is recommended but not required if wrapped in a Conditional func (r ReflectMulti) Locate() ReflectMulti { - r.location = getLocation() + r.location = getLocation(2) return r } diff --git a/task.go b/task.go index 7ef3502..9d5b2ec 100644 --- a/task.go +++ b/task.go @@ -91,7 +91,7 @@ func NewTask( depends: depends, provides: provides, fn: fn, - location: getLocation(), + location: getLocation(2), } } @@ -103,7 +103,7 @@ func NoOutputTask(name string, fn func(ctx context.Context, b Binder) error, dep fn: func(ctx context.Context, b Binder) ([]Binding, error) { return nil, fn(ctx, b) }, - location: getLocation(), + location: getLocation(2), } } @@ -125,7 +125,7 @@ func SimpleTask[T any]( } return []Binding{key.Bind(val)}, nil }, - location: getLocation(), + location: getLocation(2), } } @@ -151,7 +151,7 @@ func SimpleTask1[A1, Res any]( } return []Binding{resKey.Bind(res)}, nil }, - location: getLocation(), + location: getLocation(2), } } @@ -182,7 +182,7 @@ func SimpleTask2[A1, A2, Res any]( } return []Binding{resKey.Bind(res)}, nil }, - location: getLocation(), + location: getLocation(2), } } @@ -266,7 +266,7 @@ type Conditional struct { // Locate annotates the Conditional with its location in the source code, to make error messages // easier to understand. Calling it is required. func (c Conditional) Locate() Conditional { - c.location = getLocation() + c.location = getLocation(2) return c } @@ -321,6 +321,6 @@ func AllBound(name string, result Key[bool], deps ...ID) Task { fn: func(_ context.Context, _ Binder) ([]Binding, error) { return []Binding{result.Bind(true)}, nil }, - location: getLocation(), + location: getLocation(2), } } diff --git a/util.go b/util.go index 8592a6b..60face8 100644 --- a/util.go +++ b/util.go @@ -135,9 +135,9 @@ func MissingMaybe(maybes map[string]MaybeStatus) []string { return out } -func getLocation() string { - // Skip 1 for this function, and 1 for the constructor calling this. - if _, file, line, ok := runtime.Caller(2); ok { +func getLocation(skip int) string { + // Skip the requested number of stack frames. + if _, file, line, ok := runtime.Caller(skip); ok { return fmt.Sprintf("%s:%d", file, line) } return ""