diff --git a/mysql/query.go b/mysql/query.go index af00c1b..8217cae 100644 --- a/mysql/query.go +++ b/mysql/query.go @@ -19,7 +19,7 @@ func (b MySQLBackend) QueryEvents(ctx context.Context, filter nostr.Filter) (ch return nil, err } - rows, err := b.DB.Query(query, params...) + rows, err := b.DB.QueryContext(ctx, query, params...) if err != nil && err != sql.ErrNoRows { close(ch) return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err) @@ -37,7 +37,10 @@ func (b MySQLBackend) QueryEvents(ctx context.Context, filter nostr.Filter) (ch return } evt.CreatedAt = nostr.Timestamp(timestamp) - ch <- &evt + select { + case ch <- &evt: + case <-ctx.Done(): + } } }() @@ -51,7 +54,7 @@ func (b MySQLBackend) CountEvents(ctx context.Context, filter nostr.Filter) (int } var count int64 - if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { + if err = b.DB.QueryRowContext(ctx, query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err) } return count, nil diff --git a/postgresql/query.go b/postgresql/query.go index 72dfb39..624cf9e 100644 --- a/postgresql/query.go +++ b/postgresql/query.go @@ -16,7 +16,7 @@ func (b PostgresBackend) QueryEvents(ctx context.Context, filter nostr.Filter) ( return nil, err } - rows, err := b.DB.Query(query, params...) + rows, err := b.DB.QueryContext(ctx, query, params...) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err) } @@ -34,7 +34,10 @@ func (b PostgresBackend) QueryEvents(ctx context.Context, filter nostr.Filter) ( return } evt.CreatedAt = nostr.Timestamp(timestamp) - ch <- &evt + select { + case ch <- &evt: + case <-ctx.Done(): + } } }() @@ -48,7 +51,7 @@ func (b PostgresBackend) CountEvents(ctx context.Context, filter nostr.Filter) ( } var count int64 - if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { + if err = b.DB.QueryRowContext(ctx, query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err) } return count, nil diff --git a/sqlite3/query.go b/sqlite3/query.go index cac82f6..0b3eda0 100644 --- a/sqlite3/query.go +++ b/sqlite3/query.go @@ -16,7 +16,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter nostr.Filter) (c return nil, err } - rows, err := b.DB.Query(query, params...) + rows, err := b.DB.QueryContext(ctx, query, params...) if err != nil && err != sql.ErrNoRows { return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err) } @@ -34,7 +34,10 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter nostr.Filter) (c return } evt.CreatedAt = nostr.Timestamp(timestamp) - ch <- &evt + select { + case ch <- &evt: + case <-ctx.Done(): + } } }() @@ -48,7 +51,7 @@ func (b SQLite3Backend) CountEvents(ctx context.Context, filter nostr.Filter) (i } var count int64 - if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { + if err = b.DB.QueryRowContext(ctx, query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err) } return count, nil