From 284d5e19f683054547f5b40d8741cf59adfd5bde Mon Sep 17 00:00:00 2001 From: Mathias Bergqvist Date: Wed, 6 May 2026 22:48:02 +0200 Subject: [PATCH] feat(tools): pr_merge --- internal/gitea/pulls.go | 19 +++++++++ internal/gitea/pulls_test.go | 34 +++++++++++++++ internal/tools/pr_merge.go | 76 +++++++++++++++++++++++++++++++++ internal/tools/pr_merge_test.go | 70 ++++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+) create mode 100644 internal/tools/pr_merge.go create mode 100644 internal/tools/pr_merge_test.go diff --git a/internal/gitea/pulls.go b/internal/gitea/pulls.go index 694222c..529e2c9 100644 --- a/internal/gitea/pulls.go +++ b/internal/gitea/pulls.go @@ -103,6 +103,25 @@ func (c *Client) GetPullRequestDiff(ctx context.Context, owner, repo string, ind return resp.Body, nil } +type MergePRArgs struct { + Do string `json:"Do"` + Title string `json:"merge_message_title,omitempty"` + Body string `json:"merge_message_field,omitempty"` +} + +func (c *Client) MergePullRequest(ctx context.Context, owner, repo string, index int, args MergePRArgs) error { + p := fmt.Sprintf("/api/v1/repos/%s/%s/pulls/%d/merge", owner, repo, index) + payload, err := json.Marshal(args) + if err != nil { + return err + } + body, status, err := c.PostJSON(ctx, p, payload) + if err != nil { + return err + } + return MapStatus(status, body) +} + func (c *Client) ListPullRequests(ctx context.Context, owner, repo, state, head string, page, limit int) ([]PullRequest, error) { if page < 1 { page = 1 diff --git a/internal/gitea/pulls_test.go b/internal/gitea/pulls_test.go index c2c6f75..4fd8c58 100644 --- a/internal/gitea/pulls_test.go +++ b/internal/gitea/pulls_test.go @@ -154,3 +154,37 @@ func TestListPullRequests(t *testing.T) { assert.Equal(t, 7, prs[0].Number) assert.Equal(t, "feat/x", prs[0].Head.Ref) } + +func TestMergePullRequestSuccess(t *testing.T) { + var captured []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v1/repos/o/r/pulls/7/merge", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + var err error + captured, err = io.ReadAll(r.Body) + require.NoError(t, err) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + c := gitea.NewClient(srv.URL, "tok") + err := c.MergePullRequest(context.Background(), "o", "r", 7, gitea.MergePRArgs{Do: "squash"}) + require.NoError(t, err) + + var payload map[string]any + require.NoError(t, json.Unmarshal(captured, &payload)) + assert.Equal(t, "squash", payload["Do"]) +} + +func TestMergePullRequestConflict(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusConflict) + _, _ = w.Write([]byte(`{"message":"merge conflict"}`)) + })) + defer srv.Close() + + c := gitea.NewClient(srv.URL, "tok") + err := c.MergePullRequest(context.Background(), "o", "r", 7, gitea.MergePRArgs{Do: "merge"}) + require.Error(t, err) + assert.ErrorIs(t, err, gitea.ErrConflict) +} diff --git a/internal/tools/pr_merge.go b/internal/tools/pr_merge.go new file mode 100644 index 0000000..6e6da4a --- /dev/null +++ b/internal/tools/pr_merge.go @@ -0,0 +1,76 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + + "gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist" + "gitea.d-ma.be/mathias/gitea-mcp/internal/gitea" + "gitea.d-ma.be/mathias/gitea-mcp/internal/registry" +) + +type PRMerge struct { + c *gitea.Client + a *allowlist.Allowlist +} + +func NewPRMerge(c *gitea.Client, a *allowlist.Allowlist) *PRMerge { + return &PRMerge{c: c, a: a} +} + +func (t *PRMerge) Descriptor() registry.ToolDescriptor { + return registry.ToolDescriptor{ + Name: "pr_merge", + Description: "Merge a pull request. style: merge (default), squash, or rebase.", + InputSchema: json.RawMessage(`{ + "type":"object", + "properties":{ + "owner":{"type":"string"}, + "name":{"type":"string"}, + "index":{"type":"integer","minimum":1}, + "style":{"type":"string","enum":["merge","squash","rebase"]}, + "merge_message_title":{"type":"string"}, + "merge_message_field":{"type":"string"} + }, + "required":["owner","name","index"] + }`), + } +} + +type prMergeArgs struct { + Owner string `json:"owner"` + Name string `json:"name"` + Index int `json:"index"` + Style string `json:"style"` + Title string `json:"merge_message_title"` + Body string `json:"merge_message_field"` +} + +func (t *PRMerge) Call(ctx context.Context, raw json.RawMessage) (json.RawMessage, error) { + var args prMergeArgs + if err := parseArgs(raw, &args); err != nil { + return nil, err + } + if err := t.a.Check(args.Owner); err != nil { + return nil, err + } + if args.Index < 1 { + return nil, fmt.Errorf("index must be >= 1: %w", gitea.ErrValidation) + } + + style := args.Style + if style == "" { + style = "merge" + } + + if err := t.c.MergePullRequest(ctx, args.Owner, args.Name, args.Index, gitea.MergePRArgs{ + Do: style, + Title: args.Title, + Body: args.Body, + }); err != nil { + return nil, err + } + + return textOK(map[string]any{"merged": true}) +} diff --git a/internal/tools/pr_merge_test.go b/internal/tools/pr_merge_test.go new file mode 100644 index 0000000..3b586c9 --- /dev/null +++ b/internal/tools/pr_merge_test.go @@ -0,0 +1,70 @@ +package tools_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "gitea.d-ma.be/mathias/gitea-mcp/internal/allowlist" + "gitea.d-ma.be/mathias/gitea-mcp/internal/gitea" + "gitea.d-ma.be/mathias/gitea-mcp/internal/tools" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPRMergeSuccess(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v1/repos/owner/repo/pulls/7/merge", r.URL.Path) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + tool := tools.NewPRMerge(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"owner"})) + out, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"owner","name":"repo","index":7}`)) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(out, &result)) + assert.Equal(t, true, result["merged"]) +} + +func TestPRMergeDefaultsToMergeStyle(t *testing.T) { + var captured []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + captured, err = io.ReadAll(r.Body) + require.NoError(t, err) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + tool := tools.NewPRMerge(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"owner"})) + _, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"owner","name":"repo","index":7}`)) + require.NoError(t, err) + + var payload map[string]any + require.NoError(t, json.Unmarshal(captured, &payload)) + assert.Equal(t, "merge", payload["Do"]) +} + +func TestPRMergeConflictReturnsError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusConflict) + _, _ = w.Write([]byte(`{"message":"merge conflict"}`)) + })) + defer srv.Close() + + tool := tools.NewPRMerge(gitea.NewClient(srv.URL, "tok"), allowlist.New([]string{"owner"})) + _, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"owner","name":"repo","index":7}`)) + require.Error(t, err) + assert.ErrorIs(t, err, gitea.ErrConflict) +} + +func TestPRMergeAllowlistRejects(t *testing.T) { + tool := tools.NewPRMerge(gitea.NewClient("http://unused", ""), allowlist.New([]string{"allowed"})) + _, err := tool.Call(context.Background(), json.RawMessage(`{"owner":"evil","name":"repo","index":1}`)) + require.Error(t, err) +}