diff --git a/internal/service/http_test.go b/internal/service/http_test.go new file mode 100644 index 0000000..83a0087 --- /dev/null +++ b/internal/service/http_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "io" + "net/http" + "strings" + "testing" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestUserAgentTransportSetsConfiguredUserAgent(t *testing.T) { + transport := &userAgentTransport{ + base: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Header.Get("User-Agent"); got != "feedreader/0.1" { + t.Fatalf("unexpected user-agent: %q", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + }, nil + }), + userAgent: "feedreader/0.1", + } + + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip: %v", err) + } + _ = resp.Body.Close() +} + +func TestUserAgentTransportPreservesExplicitUserAgent(t *testing.T) { + transport := &userAgentTransport{ + base: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if got := req.Header.Get("User-Agent"); got != "custom-agent/1.0" { + t.Fatalf("unexpected user-agent: %q", got) + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + }, nil + }), + userAgent: "feedreader/0.1", + } + + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + req.Header.Set("User-Agent", "custom-agent/1.0") + resp, err := transport.RoundTrip(req) + if err != nil { + t.Fatalf("round trip: %v", err) + } + _ = resp.Body.Close() +} diff --git a/internal/service/service.go b/internal/service/service.go index c924a6d..d8c9dcf 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -27,12 +27,39 @@ func New(cfg config.Config, repo *repository.SQLiteRepository) *FeedService { cfg: cfg, repo: repo, sources: sources.Build(), - client: &http.Client{ - Timeout: time.Duration(cfg.RequestTimeoutSec * float64(time.Second)), + client: newHTTPClient(cfg), + } +} + +func newHTTPClient(cfg config.Config) *http.Client { + return &http.Client{ + Timeout: time.Duration(cfg.RequestTimeoutSec * float64(time.Second)), + Transport: &userAgentTransport{ + base: http.DefaultTransport, + userAgent: strings.TrimSpace(cfg.UserAgent), }, } } +type userAgentTransport struct { + base http.RoundTripper + userAgent string +} + +func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + transport := t.base + if transport == nil { + transport = http.DefaultTransport + } + if strings.TrimSpace(t.userAgent) == "" || req.Header.Get("User-Agent") != "" { + return transport.RoundTrip(req) + } + clone := req.Clone(req.Context()) + clone.Header = req.Header.Clone() + clone.Header.Set("User-Agent", t.userAgent) + return transport.RoundTrip(clone) +} + func (s *FeedService) StartScheduler(ctx context.Context) { go func() { location := loadScheduleLocation() diff --git a/internal/sources/alphaxiv.go b/internal/sources/alphaxiv.go index 5d66ea5..59e334d 100644 --- a/internal/sources/alphaxiv.go +++ b/internal/sources/alphaxiv.go @@ -21,11 +21,7 @@ func (AlphaXivSource) Label() string { return "alphaXiv" } func (AlphaXivSource) HomePageURL() string { return "https://www.alphaxiv.org/" } func (s AlphaXivSource) Fetch(ctx context.Context, client *http.Client) ([]domain.FeedItem, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.HomePageURL(), nil) - if err != nil { - return nil, err - } - resp, err := client.Do(req) + resp, err := getWithRetry(ctx, client, s.HomePageURL()) if err != nil { return nil, err } diff --git a/internal/sources/github.go b/internal/sources/github.go index 8379dbe..2d99e2a 100644 --- a/internal/sources/github.go +++ b/internal/sources/github.go @@ -19,11 +19,7 @@ func (GitHubTrendingSource) Label() string { return "GitHub Trending" } func (GitHubTrendingSource) HomePageURL() string { return "https://github.com/trending" } func (s GitHubTrendingSource) Fetch(ctx context.Context, client *http.Client) ([]domain.FeedItem, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.HomePageURL(), nil) - if err != nil { - return nil, err - } - resp, err := client.Do(req) + resp, err := getWithRetry(ctx, client, s.HomePageURL()) if err != nil { return nil, err } diff --git a/internal/sources/hackernews.go b/internal/sources/hackernews.go index bafebd8..bb23f89 100644 --- a/internal/sources/hackernews.go +++ b/internal/sources/hackernews.go @@ -2,7 +2,7 @@ package sources import ( "context" - "encoding/xml" + "encoding/json" "html" "io" "net/http" @@ -14,6 +14,8 @@ import ( "feedreader/internal/domain" ) +const hackerNewsFrontPageAPI = "https://hn.algolia.com/api/v1/search?tags=front_page" + type HackerNewsSource struct{} func (HackerNewsSource) Key() string { return "hackernews" } @@ -21,11 +23,7 @@ func (HackerNewsSource) Label() string { return "Hacker News" } func (HackerNewsSource) HomePageURL() string { return "https://news.ycombinator.com/" } func (s HackerNewsSource) Fetch(ctx context.Context, client *http.Client) ([]domain.FeedItem, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://hnrss.org/frontpage", nil) - if err != nil { - return nil, err - } - resp, err := client.Do(req) + resp, err := getWithRetry(ctx, client, hackerNewsFrontPageAPI) if err != nil { return nil, err } @@ -41,52 +39,54 @@ func (s HackerNewsSource) Fetch(ctx context.Context, client *http.Client) ([]dom return parseHackerNews(body) } -type hnRSS struct { - Channel struct { - Items []hnItem `xml:"item"` - } `xml:"channel"` +type hnFrontPage struct { + Hits []hnStory `json:"hits"` } -type hnItem struct { - Title string `xml:"title"` - Description string `xml:"description"` - PubDate string `xml:"pubDate"` - Link string `xml:"link"` - Comments string `xml:"comments"` - Guid string `xml:"guid"` - Creator string `xml:"creator"` +type hnStory struct { + ObjectID string `json:"objectID"` + StoryID int `json:"story_id"` + Title string `json:"title"` + StoryTitle string `json:"story_title"` + URL string `json:"url"` + StoryURL string `json:"story_url"` + StoryText string `json:"story_text"` + CommentText string `json:"comment_text"` + Author string `json:"author"` + Points *int `json:"points"` + NumComments *int `json:"num_comments"` + CreatedAt string `json:"created_at"` } func parseHackerNews(payload []byte) ([]domain.FeedItem, error) { - var rss hnRSS - if err := xml.Unmarshal(payload, &rss); err != nil { + var rss hnFrontPage + if err := json.Unmarshal(payload, &rss); err != nil { return nil, err } - items := make([]domain.FeedItem, 0, len(rss.Channel.Items)) - for idx, node := range rss.Channel.Items { - var publishedAt *time.Time - if node.PubDate != "" { - if parsed, err := time.Parse(time.RFC1123Z, node.PubDate); err == nil { - t := parsed.UTC() - publishedAt = &t - } + items := make([]domain.FeedItem, 0, len(rss.Hits)) + for idx, node := range rss.Hits { + externalID := strings.TrimSpace(node.ObjectID) + if externalID == "" && node.StoryID > 0 { + externalID = strconv.Itoa(node.StoryID) } - score := extractInt(node.Description, `Points:\s*(\d+)`) - commentsCount := extractInt(node.Description, `# Comments:\s*(\d+)`) + if externalID == "" { + continue + } + commentsURL := "https://news.ycombinator.com/item?id=" + externalID metadata := map[string]any{} - if commentsCount != nil { - metadata["comments_count"] = *commentsCount + if node.NumComments != nil { + metadata["comments_count"] = *node.NumComments } items = append(items, domain.FeedItem{ Source: "hackernews", - ExternalID: extractStoryID(firstNonEmpty(node.Comments, node.Guid, node.Link)), - Title: strings.TrimSpace(node.Title), - URL: strings.TrimSpace(node.Link), - Summary: cleanString(extractHNSummary(node.Description)), - Author: cleanString(strings.TrimSpace(node.Creator)), - Score: score, - CommentsURL: cleanString(strings.TrimSpace(node.Comments)), - PublishedAt: publishedAt, + ExternalID: externalID, + Title: strings.TrimSpace(firstNonEmpty(node.Title, node.StoryTitle, externalID)), + URL: strings.TrimSpace(firstNonEmpty(node.URL, node.StoryURL, commentsURL)), + Summary: cleanString(extractHNSummary(firstNonEmpty(node.StoryText, node.CommentText))), + Author: cleanString(strings.TrimSpace(node.Author)), + Score: node.Points, + CommentsURL: cleanString(commentsURL), + PublishedAt: parseHackerNewsTime(node.CreatedAt), SourceRank: idx + 1, Metadata: metadata, }) @@ -94,62 +94,21 @@ func parseHackerNews(payload []byte) ([]domain.FeedItem, error) { return items, nil } -func extractStoryID(value string) string { - re := regexp.MustCompile(`id=(\d+)`) - if match := re.FindStringSubmatch(value); len(match) == 2 { - return match[1] - } - return strings.TrimSpace(value) -} - -func extractHNSummary(description string) string { - head := strings.SplitN(description, "
", 2)[0] - replacer := regexp.MustCompile(`]+>||<[^>]+>`) - cleaned := replacer.ReplaceAllString(head, " ") - cleaned = html.UnescapeString(cleaned) - patterns := []*regexp.Regexp{ - regexp.MustCompile(`Comments URL:\s*\S+`), - regexp.MustCompile(`Article URL:\s*\S+`), - regexp.MustCompile(`Points:\s*\d+`), - regexp.MustCompile(`# Comments:\s*\d+`), - regexp.MustCompile(`\s+`), - } - for _, pattern := range patterns { - cleaned = pattern.ReplaceAllString(cleaned, " ") - } - return strings.TrimSpace(cleaned) -} - -func extractInt(value, pattern string) *int { - re := regexp.MustCompile(pattern) - match := re.FindStringSubmatch(value) - if len(match) != 2 { - return nil - } - parsed := strings.ReplaceAll(match[1], ",", "") - if parsed == "" { +func parseHackerNewsTime(value string) *time.Time { + if strings.TrimSpace(value) == "" { return nil } - out, err := strconv.Atoi(parsed) + parsed, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) if err != nil { return nil } - return &out + utc := parsed.UTC() + return &utc } -func firstNonEmpty(values ...string) string { - for _, value := range values { - if strings.TrimSpace(value) != "" { - return value - } - } - return "" -} - -func cleanString(value string) *string { - value = strings.TrimSpace(value) - if value == "" { - return nil - } - return &value +func extractHNSummary(description string) string { + cleaned := regexp.MustCompile(`]+>||<[^>]+>`).ReplaceAllString(description, " ") + cleaned = html.UnescapeString(cleaned) + cleaned = regexp.MustCompile(`\s+`).ReplaceAllString(cleaned, " ") + return strings.TrimSpace(cleaned) } diff --git a/internal/sources/hackernews_test.go b/internal/sources/hackernews_test.go new file mode 100644 index 0000000..92edc5d --- /dev/null +++ b/internal/sources/hackernews_test.go @@ -0,0 +1,83 @@ +package sources + +import ( + "testing" + "time" +) + +func TestParseHackerNewsFrontPageJSON(t *testing.T) { + payload := []byte(`{ + "hits": [ + { + "objectID": "123", + "title": "Example story", + "url": "https://example.com/story", + "author": "alice", + "points": 42, + "num_comments": 11, + "created_at": "2026-06-21T12:44:13Z", + "story_text": "

Example & summary.

" + }, + { + "objectID": "456", + "title": "Ask HN: Fallback URL", + "author": "bob", + "points": 7, + "num_comments": 3, + "created_at": "2026-06-21T13:00:00Z" + } + ] + }`) + + items, err := parseHackerNews(payload) + if err != nil { + t.Fatalf("parse hacker news payload: %v", err) + } + if len(items) != 2 { + t.Fatalf("expected 2 items, got %d", len(items)) + } + + first := items[0] + if first.Source != "hackernews" { + t.Fatalf("unexpected source: %q", first.Source) + } + if first.ExternalID != "123" { + t.Fatalf("unexpected external id: %q", first.ExternalID) + } + if first.Title != "Example story" { + t.Fatalf("unexpected title: %q", first.Title) + } + if first.URL != "https://example.com/story" { + t.Fatalf("unexpected url: %q", first.URL) + } + if first.Author == nil || *first.Author != "alice" { + t.Fatalf("unexpected author: %#v", first.Author) + } + if first.Score == nil || *first.Score != 42 { + t.Fatalf("unexpected score: %#v", first.Score) + } + if first.Summary == nil || *first.Summary != "Example & summary." { + t.Fatalf("unexpected summary: %#v", first.Summary) + } + if first.CommentsURL == nil || *first.CommentsURL != "https://news.ycombinator.com/item?id=123" { + t.Fatalf("unexpected comments url: %#v", first.CommentsURL) + } + if got, ok := first.Metadata["comments_count"].(int); !ok || got != 11 { + t.Fatalf("unexpected comments_count metadata: %#v", first.Metadata["comments_count"]) + } + wantPublishedAt := time.Date(2026, time.June, 21, 12, 44, 13, 0, time.UTC) + if first.PublishedAt == nil || !first.PublishedAt.Equal(wantPublishedAt) { + t.Fatalf("unexpected publishedAt: %#v", first.PublishedAt) + } + if first.SourceRank != 1 { + t.Fatalf("unexpected source rank: %d", first.SourceRank) + } + + second := items[1] + if second.URL != "https://news.ycombinator.com/item?id=456" { + t.Fatalf("expected comments-url fallback, got %q", second.URL) + } + if second.CommentsURL == nil || *second.CommentsURL != "https://news.ycombinator.com/item?id=456" { + t.Fatalf("unexpected fallback comments url: %#v", second.CommentsURL) + } +} diff --git a/internal/sources/http_test.go b/internal/sources/http_test.go new file mode 100644 index 0000000..cfc0ea8 --- /dev/null +++ b/internal/sources/http_test.go @@ -0,0 +1,108 @@ +package sources + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestGetWithRetryRetriesTransportError(t *testing.T) { + original := upstreamRetryDelays + upstreamRetryDelays = []time.Duration{0, 0} + defer func() { upstreamRetryDelays = original }() + + attempts := 0 + client := &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + attempts++ + if attempts < 3 { + return nil, errors.New("connection reset by peer") + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("ok")), + Header: make(http.Header), + }, nil + })} + + resp, err := getWithRetry(context.Background(), client, "https://example.com") + if err != nil { + t.Fatalf("getWithRetry returned error: %v", err) + } + defer resp.Body.Close() + if attempts != 3 { + t.Fatalf("expected 3 attempts, got %d", attempts) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } +} + +func TestGetWithRetryRetriesRetryableStatus(t *testing.T) { + original := upstreamRetryDelays + upstreamRetryDelays = []time.Duration{0} + defer func() { upstreamRetryDelays = original }() + + attempts := 0 + client := &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + attempts++ + status := http.StatusBadGateway + if attempts == 2 { + status = http.StatusOK + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(http.StatusText(status))), + Header: make(http.Header), + }, nil + })} + + resp, err := getWithRetry(context.Background(), client, "https://example.com") + if err != nil { + t.Fatalf("getWithRetry returned error: %v", err) + } + defer resp.Body.Close() + if attempts != 2 { + t.Fatalf("expected 2 attempts, got %d", attempts) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } +} + +func TestGetWithRetryDoesNotRetryNonRetryableStatus(t *testing.T) { + original := upstreamRetryDelays + upstreamRetryDelays = []time.Duration{time.Hour} + defer func() { upstreamRetryDelays = original }() + + attempts := 0 + client := &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + attempts++ + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("missing")), + Header: make(http.Header), + }, nil + })} + + resp, err := getWithRetry(context.Background(), client, "https://example.com") + if err != nil { + t.Fatalf("getWithRetry returned error: %v", err) + } + defer resp.Body.Close() + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } +} diff --git a/internal/sources/huggingface.go b/internal/sources/huggingface.go index 8ac82d8..ce8cd1d 100644 --- a/internal/sources/huggingface.go +++ b/internal/sources/huggingface.go @@ -21,11 +21,7 @@ func (HuggingFacePapersSource) Label() string { return "Hugging Face Paper func (HuggingFacePapersSource) HomePageURL() string { return "https://huggingface.co/papers/trending" } func (s HuggingFacePapersSource) Fetch(ctx context.Context, client *http.Client) ([]domain.FeedItem, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.HomePageURL(), nil) - if err != nil { - return nil, err - } - resp, err := client.Do(req) + resp, err := getWithRetry(ctx, client, s.HomePageURL()) if err != nil { return nil, err } diff --git a/internal/sources/util.go b/internal/sources/util.go index 543c81d..0e056bb 100644 --- a/internal/sources/util.go +++ b/internal/sources/util.go @@ -1,6 +1,17 @@ package sources -import "fmt" +import ( + "context" + "fmt" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "time" +) + +var upstreamRetryDelays = []time.Duration{250 * time.Millisecond, 750 * time.Millisecond} type httpError struct { StatusCode int @@ -13,3 +24,107 @@ func (e *httpError) Error() string { } return fmt.Sprintf("unexpected status %d: %s", e.StatusCode, e.Body) } + +func getWithRetry(ctx context.Context, client *http.Client, rawURL string) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + attempts := len(upstreamRetryDelays) + 1 + for attempt := 1; attempt <= attempts; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err == nil { + if !shouldRetryStatus(resp.StatusCode) || attempt == attempts { + return resp, nil + } + drainAndClose(resp.Body) + } else if ctx.Err() != nil || attempt == attempts { + return nil, err + } + if err := sleepWithContext(ctx, upstreamRetryDelay(attempt)); err != nil { + return nil, err + } + } + return nil, fmt.Errorf("upstream retry loop exhausted for %s", rawURL) +} + +func upstreamRetryDelay(attempt int) time.Duration { + if attempt <= 0 || attempt > len(upstreamRetryDelays) { + return 0 + } + return upstreamRetryDelays[attempt-1] +} + +func shouldRetryStatus(statusCode int) bool { + switch statusCode { + case http.StatusRequestTimeout, + http.StatusTooEarly, + http.StatusTooManyRequests, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + return true + default: + return false + } +} + +func drainAndClose(body io.ReadCloser) { + if body == nil { + return + } + _, _ = io.Copy(io.Discard, io.LimitReader(body, 4096)) + _ = body.Close() +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func extractInt(value, pattern string) *int { + re := regexp.MustCompile(pattern) + match := re.FindStringSubmatch(value) + if len(match) != 2 { + return nil + } + parsed := strings.ReplaceAll(match[1], ",", "") + if parsed == "" { + return nil + } + out, err := strconv.Atoi(parsed) + if err != nil { + return nil + } + return &out +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func cleanString(value string) *string { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + return &value +}