diff --git a/.github/workflows/wasmbus-go.yaml b/.github/workflows/wasmbus-go.yaml index e5f7a0a5..5e85e507 100644 --- a/.github/workflows/wasmbus-go.yaml +++ b/.github/workflows/wasmbus-go.yaml @@ -78,6 +78,18 @@ jobs: working-directory: x/wasmbus/events run: go test -cover -v -wash-output + - name: wasmbus/policy + working-directory: x/wasmbus/policy + run: go test -cover -v -wash-output + + - name: wasmbus/config + working-directory: x/wasmbus/config + run: go test -cover -v -wash-output + + - name: wasmbus/secrets + working-directory: x/wasmbus/secrets + run: go test -cover -v -wash-output + examples: # Context: https://github.com/golangci/golangci-lint-action/blob/v6.1.1/README.md#annotations permissions: diff --git a/x/wasmbus/bus.go b/x/wasmbus/bus.go index da179ac8..24430281 100644 --- a/x/wasmbus/bus.go +++ b/x/wasmbus/bus.go @@ -22,6 +22,10 @@ const ( PrefixEvents = "wasmbus.evt" // PrefixControl is the prefix for Lattice RPC. PrefixCtlV1 = "wasmbus.ctl.v1" + + PrefixConfig = "wasmbus.cfg" + + PrefixSecrets = "wasmcloud.secrets" ) var ( diff --git a/x/wasmbus/config/api.go b/x/wasmbus/config/api.go new file mode 100644 index 00000000..63fb163b --- /dev/null +++ b/x/wasmbus/config/api.go @@ -0,0 +1,39 @@ +package config + +import ( + "context" + "fmt" +) + +var ( + ErrProtocol = fmt.Errorf("encoding error") + ErrInternal = fmt.Errorf("internal error") +) + +type API interface { + // Host is currently the only method exposed by the API. + Host(ctx context.Context, req *HostRequest) (*HostResponse, error) +} + +var _ API = (*APIMock)(nil) + +type APIMock struct { + HostFunc func(ctx context.Context, req *HostRequest) (*HostResponse, error) +} + +func (m *APIMock) Host(ctx context.Context, req *HostRequest) (*HostResponse, error) { + return m.HostFunc(ctx, req) +} + +type HostRequest struct { + Labels map[string]string `json:"labels"` +} + +type HostResponse struct { + RegistryCredentials map[string]RegistryCredential `json:"registryCredentials,omitempty"` +} + +type RegistryCredential struct { + Username string `json:"username"` + Password string `json:"password"` +} diff --git a/x/wasmbus/config/server.go b/x/wasmbus/config/server.go new file mode 100644 index 00000000..444e85a5 --- /dev/null +++ b/x/wasmbus/config/server.go @@ -0,0 +1,27 @@ +package config + +import ( + "fmt" + + "go.wasmcloud.dev/x/wasmbus" +) + +type Server struct { + *wasmbus.Server + Lattice string + api API +} + +func NewServer(bus wasmbus.Bus, lattice string, api API) *Server { + return &Server{ + Server: wasmbus.NewServer(bus), + Lattice: lattice, + api: api, + } +} + +func (s *Server) Serve() error { + subject := fmt.Sprintf("%s.%s.req", wasmbus.PrefixConfig, s.Lattice) + handler := wasmbus.NewRequestHandler(HostRequest{}, HostResponse{}, s.api.Host) + return s.RegisterHandler(subject, handler) +} diff --git a/x/wasmbus/config/server_test.go b/x/wasmbus/config/server_test.go new file mode 100644 index 00000000..f6972e4e --- /dev/null +++ b/x/wasmbus/config/server_test.go @@ -0,0 +1,68 @@ +package config + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/nats-io/nats.go" + "go.wasmcloud.dev/x/wasmbus" + "go.wasmcloud.dev/x/wasmbus/wasmbustest" +) + +func TestServer(t *testing.T) { + defer wasmbustest.MustStartNats(t)() + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatalf("failed to connect to nats: %v", err) + } + bus := wasmbus.NewNatsBus(nc) + s := NewServer(bus, "test", &APIMock{ + HostFunc: func(ctx context.Context, req *HostRequest) (*HostResponse, error) { + return &HostResponse{ + RegistryCredentials: map[string]RegistryCredential{ + "docker.io": { + Username: "my-username", + Password: "hunter2", + }, + }, + }, nil + }, + }) + if err := s.Serve(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + + req := wasmbus.NewMessage(fmt.Sprintf("%s.%s.req", wasmbus.PrefixConfig, "test")) + req.Data = []byte(`{"labels":{"hostcore.arch":"aarch64","hostcore.os":"linux","hostcore.osfamily":"unix","kubernetes":"true","kubernetes.hostgroup":"default"}}`) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + rawResp, err := bus.Request(ctx, req) + if err != nil { + t.Fatal(err) + } + + var resp HostResponse + if err := wasmbus.Decode(rawResp, &resp); err != nil { + t.Fatal(err) + } + + docker, ok := resp.RegistryCredentials["docker.io"] + if !ok { + t.Fatalf("expected docker.io registry credentials") + } + if want, got := "my-username", docker.Username; want != got { + t.Fatalf("expected username %q, got %q", want, got) + } + + if want, got := "hunter2", docker.Password; want != got { + t.Fatalf("expected password %q, got %q", want, got) + } + + if err := s.Drain(); err != nil { + t.Fatalf("failed to drain server: %v", err) + } +} diff --git a/x/wasmbus/go.mod b/x/wasmbus/go.mod index 35172a0b..08014e43 100644 --- a/x/wasmbus/go.mod +++ b/x/wasmbus/go.mod @@ -5,8 +5,10 @@ go 1.23.3 require ( github.com/cloudevents/sdk-go/v2 v2.15.2 github.com/goccy/go-yaml v1.15.13 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/nats-io/nats-server/v2 v2.10.24 github.com/nats-io/nats.go v1.38.0 + github.com/nats-io/nkeys v0.4.9 ) require ( @@ -17,7 +19,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nats-io/jwt/v2 v2.7.3 // indirect - github.com/nats-io/nkeys v0.4.9 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/stretchr/testify v1.10.0 // indirect go.uber.org/atomic v1.11.0 // indirect diff --git a/x/wasmbus/go.sum b/x/wasmbus/go.sum index 0c7cb850..4b881fbb 100644 --- a/x/wasmbus/go.sum +++ b/x/wasmbus/go.sum @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/goccy/go-yaml v1.15.13 h1:Xd87Yddmr2rC1SLLTm2MNDcTjeO/GYo0JGiww6gSTDg= github.com/goccy/go-yaml v1.15.13/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/x/wasmbus/policy/api.go b/x/wasmbus/policy/api.go new file mode 100644 index 00000000..23a09a15 --- /dev/null +++ b/x/wasmbus/policy/api.go @@ -0,0 +1,122 @@ +package policy + +import ( + "context" + "fmt" +) + +var ( + ErrProtocol = fmt.Errorf("encoding error") + ErrInternal = fmt.Errorf("internal error") +) + +type API interface { + // PerformInvocation is called when a component is invoked + PerformInvocation(ctx context.Context, req *PerformInvocationRequest) (*Response, error) + // StartComponent is called when a component is started + StartComponent(ctx context.Context, req *StartComponentRequest) (*Response, error) + // StartProvider is called when a provider is started + StartProvider(ctx context.Context, req *StartProviderRequest) (*Response, error) +} + +var _ API = (*APIMock)(nil) + +type APIMock struct { + PerformInvocationFunc func(ctx context.Context, req *PerformInvocationRequest) (*Response, error) + StartComponentFunc func(ctx context.Context, req *StartComponentRequest) (*Response, error) + StartProviderFunc func(ctx context.Context, req *StartProviderRequest) (*Response, error) +} + +func (m *APIMock) PerformInvocation(ctx context.Context, req *PerformInvocationRequest) (*Response, error) { + return m.PerformInvocationFunc(ctx, req) +} + +func (m *APIMock) StartComponent(ctx context.Context, req *StartComponentRequest) (*Response, error) { + return m.StartComponentFunc(ctx, req) +} + +func (m *APIMock) StartProvider(ctx context.Context, req *StartProviderRequest) (*Response, error) { + return m.StartProviderFunc(ctx, req) +} + +// Request is the structure of the request sent to the policy engine +type BaseRequest[T any] struct { + Id string `json:"requestId"` + Kind string `json:"kind"` + Version string `json:"version"` + Host Host `json:"host"` + Request T `json:"request"` +} + +// Decision is a helper function to create a response +func (r BaseRequest[T]) Decision(allowed bool, msg string) *Response { + return &Response{ + Id: r.Id, + Permitted: allowed, + Message: msg, + } +} + +// Deny is a helper function to create a response with a deny decision +func (r BaseRequest[T]) Deny(msg string) *Response { + return r.Decision(false, msg) +} + +// Allow is a helper function to create a response with an allow decision +func (r BaseRequest[T]) Allow(msg string) *Response { + return r.Decision(true, msg) +} + +// Response is the structure of the response sent by the policy engine +type Response struct { + Id string `json:"requestId"` + Permitted bool `json:"permitted"` + Message string `json:"message,omitempty"` +} + +type Claims struct { + PublicKey string `json:"publicKey"` + Issuer string `json:"issuer"` + IssuedAt int `json:"issuedAt"` + ExpiresAt int `json:"expiresAt"` + Expired bool `json:"expired"` +} + +type StartComponentPayload struct { + ComponentId string `json:"componentId"` + ImageRef string `json:"imageRef"` + MaxInstances int `json:"maxInstances"` + Annotations map[string]string `json:"annotations"` +} + +type StartComponentRequest = BaseRequest[StartComponentPayload] + +type StartProviderPayload struct { + ProviderId string `json:"providerId"` + ImageRef string `json:"imageRef"` + Annotations map[string]string `json:"annotations"` +} + +type StartProviderRequest = BaseRequest[StartProviderPayload] + +type PerformInvocationPayload struct { + Interface string `json:"interface"` + Function string `json:"function"` + // NOTE(lxf): this covers components but not providers. wut?!? + Target InvocationTarget `json:"target"` +} + +type PerformInvocationRequest = BaseRequest[PerformInvocationPayload] + +type InvocationTarget struct { + ComponentId string `json:"componentId"` + ImageRef string `json:"imageRef"` + MaxInstances int `json:"maxInstances"` + Annotations map[string]string `json:"annotations"` +} + +type Host struct { + PublicKey string `json:"publicKey"` + Lattice string `json:"lattice"` + Labels map[string]string `json:"labels"` +} diff --git a/x/wasmbus/policy/server.go b/x/wasmbus/policy/server.go new file mode 100644 index 00000000..67376bbe --- /dev/null +++ b/x/wasmbus/policy/server.go @@ -0,0 +1,59 @@ +package policy + +import ( + "context" + "encoding/json" + "fmt" + + "go.wasmcloud.dev/x/wasmbus" +) + +type Server struct { + *wasmbus.Server + subject string + api API +} + +func NewServer(bus wasmbus.Bus, subject string, api API) *Server { + return &Server{ + Server: wasmbus.NewServer(bus), + subject: subject, + api: api, + } +} + +func (s *Server) Serve() error { + handler := wasmbus.NewTypedHandler(extractType) + + startComponent := wasmbus.NewRequestHandler(StartComponentRequest{}, Response{}, s.api.StartComponent) + if err := handler.RegisterType("startComponent", startComponent); err != nil { + return err + } + + startProvider := wasmbus.NewRequestHandler(StartProviderRequest{}, Response{}, s.api.StartProvider) + if err := handler.RegisterType("startProvider", startProvider); err != nil { + return err + } + + performInvocation := wasmbus.NewRequestHandler(PerformInvocationRequest{}, Response{}, s.api.PerformInvocation) + if err := handler.RegisterType("performInvocation", performInvocation); err != nil { + return err + } + + return s.RegisterHandler(s.subject, handler) +} + +func extractType(ctx context.Context, msg *wasmbus.Message) (string, error) { + var baseReq BaseRequest[json.RawMessage] + + if err := wasmbus.Decode(msg, &baseReq); err != nil { + return "", err + } + + switch baseReq.Kind { + case "startComponent", "startProvider", "performInvocation": + return baseReq.Kind, nil + default: + return "", fmt.Errorf("unknown request kind: %s", baseReq.Kind) + } +} diff --git a/x/wasmbus/policy/server_test.go b/x/wasmbus/policy/server_test.go new file mode 100644 index 00000000..fd079163 --- /dev/null +++ b/x/wasmbus/policy/server_test.go @@ -0,0 +1,117 @@ +package policy + +import ( + "context" + "testing" + "time" + + "github.com/nats-io/nats.go" + "go.wasmcloud.dev/x/wasmbus" + "go.wasmcloud.dev/x/wasmbus/wasmbustest" +) + +func TestServer(t *testing.T) { + defer wasmbustest.MustStartNats(t)() + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatalf("failed to connect to nats: %v", err) + } + bus := wasmbus.NewNatsBus(nc) + s := NewServer(bus, "subject.test", &APIMock{ + StartComponentFunc: func(ctx context.Context, req *StartComponentRequest) (*Response, error) { + return req.Allow("passed"), nil + }, + StartProviderFunc: func(ctx context.Context, req *StartProviderRequest) (*Response, error) { + return req.Deny("denied"), nil + }, + PerformInvocationFunc: func(ctx context.Context, req *PerformInvocationRequest) (*Response, error) { + return req.Allow("passed"), nil + }, + }) + if err := s.Serve(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + + t.Run("startComponent", func(t *testing.T) { + req := wasmbus.NewMessage("subject.test") + req.Data = []byte(`{"requestId":"01945242-abec-71ee-f5e6-1f44eb61ad40","kind":"startComponent","version":"v1","request":{"componentId":"hello_world-http_component","imageRef":"ghcr.io/wasmcloud/components/http-hello-world-rust:0.1.0","maxInstances":1,"annotations":{"wasmcloud.dev/appspec":"hello-world","wasmcloud.dev/managed-by":"wadm","wasmcloud.dev/scaler":"a648fe966cbdb0a0dee3252a416f824858bbf0c1a24be850ef632626ddbb5133","wasmcloud.dev/spread_name":"default"},"claims":{"publicKey":"MBFFVNGFK3IA2ZXXG5DQXQNYM6TNG45PHJMJIJFVFI6YKS3XTXL3DRRK","issuer":"ADVIWF6Z3BFZNWUXJYT5NEAZZ2YX4T6NRKI3YOR3HKOSQQN7IVDGWSNO","issuedAt":"1714506509","expiresAt":null,"expired":false}},"host":{"publicKey":"NBLSJGGOETB677FQL63PWKDCVMOW4LXVI7S6WXSP55H7L5RNRKUDZKGE","lattice":"default","labels":{"hostcore.arch":"aarch64","hostcore.os":"linux","kubernetes":"true","kubernetes.hostgroup":"default","hostcore.osfamily":"unix"}}}`) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + rawResp, err := bus.Request(ctx, req) + if err != nil { + t.Fatal(err) + } + + resp := &Response{} + if err := wasmbus.Decode(rawResp, resp); err != nil { + t.Fatal(err) + } + + if want, got := "01945242-abec-71ee-f5e6-1f44eb61ad40", resp.Id; want != got { + t.Fatalf("expected %q, got %q", want, got) + } + + if !resp.Permitted { + t.Fatalf("expected allow, got deny") + } + }) + + t.Run("startProvider", func(t *testing.T) { + req := wasmbus.NewMessage("subject.test") + req.Data = []byte(`{"requestId":"01945242-b574-5094-790d-76d823e0c948","kind":"startProvider","version":"v1","request":{"providerId":"hello_world-httpserver","imageRef":"ghcr.io/wasmcloud/http-server:0.23.0","annotations":{"wasmcloud.dev/appspec":"hello-world","wasmcloud.dev/managed-by":"wadm","wasmcloud.dev/scaler":"e82e835acda294f1a3d9cb66c0dfc8619c82fa836a1e30142d5d2b607357fc86","wasmcloud.dev/spread_name":"default"},"claims":{"publicKey":"VAG3QITQQ2ODAOWB5TTQSDJ53XK3SHBEIFNK4AYJ5RKAX2UNSCAPHA5M","issuer":"ACOJJN6WUP4ODD75XEBKKTCCUJJCY5ZKQ56XVKYK4BEJWGVAOOQHZMCW","issuedAt":"1725897949","expiresAt":null,"expired":false}},"host":{"publicKey":"NBLSJGGOETB677FQL63PWKDCVMOW4LXVI7S6WXSP55H7L5RNRKUDZKGE","lattice":"default","labels":{"hostcore.arch":"aarch64","hostcore.os":"linux","kubernetes":"true","kubernetes.hostgroup":"default","hostcore.osfamily":"unix"}}}`) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + rawResp, err := bus.Request(ctx, req) + if err != nil { + t.Fatal(err) + } + + resp := &Response{} + if err := wasmbus.Decode(rawResp, resp); err != nil { + t.Fatal(err) + } + + if want, got := "01945242-b574-5094-790d-76d823e0c948", resp.Id; want != got { + t.Fatalf("expected %q, got %q", want, got) + } + + if resp.Permitted { + t.Fatalf("expected deny, got allow") + } + }) + + t.Run("performInvocation", func(t *testing.T) { + req := wasmbus.NewMessage("subject.test") + req.Data = []byte(`{"requestId":"01945244-9a84-a36c-a1b2-b722bd686bca","kind":"performInvocation","version":"v1","request":{"interface":"wrpc:http/incoming-handler@0.1.0","function":"handle","target":{"componentId":"hello_world-http_component","imageRef":"ghcr.io/wasmcloud/components/http-hello-world-rust:0.1.0","maxInstances":0,"annotations":{"wasmcloud.dev/appspec":"hello-world","wasmcloud.dev/managed-by":"wadm","wasmcloud.dev/scaler":"a648fe966cbdb0a0dee3252a416f824858bbf0c1a24be850ef632626ddbb5133","wasmcloud.dev/spread_name":"default"},"claims":{"publicKey":"MBFFVNGFK3IA2ZXXG5DQXQNYM6TNG45PHJMJIJFVFI6YKS3XTXL3DRRK","issuer":"ADVIWF6Z3BFZNWUXJYT5NEAZZ2YX4T6NRKI3YOR3HKOSQQN7IVDGWSNO","issuedAt":"1714506509","expiresAt":null,"expired":false}}},"host":{"publicKey":"NBLSJGGOETB677FQL63PWKDCVMOW4LXVI7S6WXSP55H7L5RNRKUDZKGE","lattice":"default","labels":{"hostcore.arch":"aarch64","hostcore.os":"linux","kubernetes":"true","kubernetes.hostgroup":"default","hostcore.osfamily":"unix"}}}`) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + rawResp, err := bus.Request(ctx, req) + if err != nil { + t.Fatal(err) + } + + resp := &Response{} + if err := wasmbus.Decode(rawResp, resp); err != nil { + t.Fatal(err) + } + + if want, got := "01945244-9a84-a36c-a1b2-b722bd686bca", resp.Id; want != got { + t.Fatalf("expected %q, got %q", want, got) + } + + if !resp.Permitted { + t.Fatalf("expected allow, got deny") + } + }) + + if err := s.Drain(); err != nil { + t.Fatalf("failed to drain server: %v", err) + } +} diff --git a/x/wasmbus/secrets/api.go b/x/wasmbus/secrets/api.go new file mode 100644 index 00000000..4406ab13 --- /dev/null +++ b/x/wasmbus/secrets/api.go @@ -0,0 +1,275 @@ +package secrets + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + PrefixVersion = "v1alpha1" + WasmCloudHostXkey = "WasmCloud-Host-Xkey" + WasmCloudResponseXkey = "Server-Response-Xkey" +) + +type APIv1alpha1 interface { + Get(ctx context.Context, req *GetRequest) (*GetResponse, error) +} + +type APIMock struct { + GetFunc func(ctx context.Context, req *GetRequest) (*GetResponse, error) +} + +func (a *APIMock) Get(ctx context.Context, req *GetRequest) (*GetResponse, error) { + return a.GetFunc(ctx, req) +} + +var ( + ErrInvalidServerConfig = errors.New("invalid server configuration") + + ErrSecretNotFound = newResponseError("SecretNotFound", false) + ErrInvalidRequest = newResponseError("InvalidRequest", false) + ErrInvalidHeaders = newResponseError("InvalidHeaders", false) + ErrInvalidPayload = newResponseError("InvalidPayload", false) + ErrEncryption = newResponseError("EncryptionError", false) + ErrDecryption = newResponseError("DecryptionError", false) + + ErrInvalidEntityJWT = newResponseError("InvalidEntityJWT", true) + ErrInvalidHostJWT = newResponseError("InvalidHostJWT", true) + ErrUpstream = newResponseError("UpstreamError", true) + ErrPolicy = newResponseError("PolicyError", true) + ErrOther = newResponseError("Other", true) +) + +type Error struct { + Tip string + HasMessage bool + Message string +} + +func (re Error) With(msg string) *Error { + otherError := re + otherError.Message = msg + return &otherError +} + +func (re Error) Error() string { + return re.Tip +} + +func (re *Error) UnmarshalJSON(data []byte) error { + serdeSpecial := make(map[string]string) + if err := json.Unmarshal(data, &serdeSpecial); err != nil { + var msg string + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + *re = *ErrOther.With(msg) + return nil + } + if len(serdeSpecial) != 1 { + return errors.New("couldn't parse ResponseError") + } + for k, v := range serdeSpecial { + *re = Error{Tip: k, HasMessage: v != "", Message: v} + break + } + + return nil +} + +func (re *Error) MarshalJSON() ([]byte, error) { + if re == nil { + return nil, nil + } + + if !re.HasMessage { + return json.Marshal(re.Tip) + } + + serdeSpecial := make(map[string]string) + serdeSpecial[re.Tip] = re.Message + + return json.Marshal(serdeSpecial) +} + +func newResponseError(tip string, hasMessage bool) *Error { + return &Error{Tip: tip, HasMessage: hasMessage} +} + +type applicationContextPolicy struct { + Type string `json:"type"` + Properties json.RawMessage `json:"properties"` +} + +func (a ApplicationContext) PolicyProperties() (json.RawMessage, error) { + policy := &applicationContextPolicy{} + err := json.Unmarshal([]byte(a.Policy), policy) + return policy.Properties, err +} + +type GetRequest struct { + Key string `json:"key"` + Field string `json:"field"` + Version string `json:"version,omitempty"` + Context Context `json:"context"` + // NOTE(lxf): HostPubKey is not part of the actual request. + // filled in by middleware. + HostPubKey string `json:"-"` +} + +// NOTE(lxf): The way we return errors here is far from optimal... +type GetResponse struct { + Secret *SecretValue `json:"secret,omitempty"` + Error *Error `json:"error,omitempty"` +} + +type SecretValue struct { + Version string `json:"version,omitempty"` + StringSecret string `json:"string_secret,omitempty"` + BinarySecret BinarySecret `json:"binary_secret,omitempty"` +} + +// NOTE(lxf): This is a rust serde special... +type BinarySecret []uint8 + +func (u BinarySecret) MarshalJSON() ([]byte, error) { + var result string + if u == nil { + return nil, nil + } + + result = strings.Join(strings.Fields(fmt.Sprintf("%d", u)), ",") + return []byte(result), nil +} + +type Context struct { + /// The application the entity belongs to. + /// TODO: should this also be a JWT, but signed by the host? + Application *ApplicationContext `json:"application,omitempty"` + /// The component or provider's signed JWT. + EntityJwt string `json:"entity_jwt"` + /// The host's signed JWT. + HostJwt string `json:"host_jwt"` +} + +func (ctx Context) IsValid() *Error { + if _, _, err := ctx.EntityCapabilities(); err != nil { + return err + } + + if _, _, err := ctx.HostCapabilities(); err != nil { + return err + } + + return nil +} + +func (ctx Context) EntityCapabilities() (*WasCap, *ComponentClaims, *Error) { + token, err := jwt.ParseWithClaims(ctx.EntityJwt, &WasCap{}, KeyPairFromIssuer()) + if err != nil { + return nil, nil, ErrInvalidEntityJWT.With(err.Error()) + } + + wasCap, ok := token.Claims.(*WasCap) + if !ok { + return nil, nil, ErrInvalidEntityJWT.With("not wascap") + } + + compCap := &ComponentClaims{} + if err := json.Unmarshal(wasCap.Was, compCap); err != nil { + return nil, nil, ErrInvalidEntityJWT.With(err.Error()) + } + + return wasCap, compCap, nil +} + +func (ctx Context) HostCapabilities() (*WasCap, *HostClaims, *Error) { + token, err := jwt.ParseWithClaims(ctx.HostJwt, &WasCap{}, KeyPairFromIssuer()) + if err != nil { + return nil, nil, ErrInvalidHostJWT.With(err.Error()) + } + + wasCap, ok := token.Claims.(*WasCap) + if !ok { + return nil, nil, ErrInvalidHostJWT.With("not wascap") + } + + hostCap := &HostClaims{} + if err := json.Unmarshal(wasCap.Was, hostCap); err != nil { + return nil, nil, ErrInvalidHostJWT.With(err.Error()) + } + + return wasCap, hostCap, nil +} + +type ComponentClaims struct { + jwt.RegisteredClaims + + /// A descriptive name for this component, should not include version information or public key + Name string `json:"name"` + /// A hash of the module's bytes as they exist without the embedded signature. This is stored so wascap + /// can determine if a WebAssembly module's bytecode has been altered after it was signed + ModuleHash string `json:"hash"` + + /// List of arbitrary string tags associated with the claims + Tags []string `json:"tags"` + + /// Indicates a monotonically increasing revision number. Optional. + Rev int32 `json:"rev"` + + /// Indicates a human-friendly version string + Ver string `json:"ver"` + + /// An optional, code-friendly alias that can be used instead of a public key or + /// OCI reference for invocations + CallAlias string `json:"call_alias"` + + /// Indicates whether this module is a capability provider + Provider bool `json:"prov"` +} + +type CapabilityProviderClaims struct { + /// A descriptive name for the capability provider + Name string `json:"name"` + /// A human-readable string identifying the vendor of this provider (e.g. Redis or Cassandra or NATS etc) + Vendor string `json:"vendor"` + /// Indicates a monotonically increasing revision number. Optional. + Rev int32 `json:"rev"` + /// Indicates a human-friendly version string. Optional. + Ver string `json:"ver"` + /// If the provider chooses, it can supply a JSON schma that describes its expected link configuration + ConfigSchema json.RawMessage `json:"config_schema,omitempty"` + /// The file hashes that correspond to the achitecture-OS target triples for this provider. + TargetHashes map[string]string `json:"target_hashes"` +} + +type HostClaims struct { + /// Optional friendly descriptive name for the host + Name string `json:"name"` + /// Optional labels for the host + Labels map[string]string `json:"labels"` +} + +type WasCap struct { + jwt.RegisteredClaims + + /// Custom jwt claims in the `wascap` namespace + Was json.RawMessage `json:"wascap,omitempty"` + + /// Internal revision number used to aid in parsing and validating claims + Revision int32 `json:"wascap_revision,omitempty"` +} + +func (w WasCap) ParseCapability(dst interface{}) error { + return json.Unmarshal(w.Was, dst) +} + +type ApplicationContext struct { + Policy string `json:"policy"` + Name string `json:"name"` +} diff --git a/x/wasmbus/secrets/api_test.go b/x/wasmbus/secrets/api_test.go new file mode 100644 index 00000000..ae91966b --- /dev/null +++ b/x/wasmbus/secrets/api_test.go @@ -0,0 +1,68 @@ +package secrets + +import ( + "encoding/json" + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/nats-io/nkeys" +) + +func TestJWTClaims(t *testing.T) { + claims := jwt.RegisteredClaims{ + // A usual scenario is to set the expiration time relative to the current time + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "test", + Subject: "somebody", + ID: "1", + Audience: []string{"somebody_else"}, + } + + kp, err := nkeys.CreateAccount() + if err != nil { + t.Fatal(err) + } + + token := jwt.NewWithClaims(SigningMethodEd25519, claims) + _, err = token.SignedString(kp) + if err != nil { + t.Fatal(err) + } + + validJWT := "eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJTakI1Zm05NzRTanU5V01nVFVjaHNiIiwiaWF0IjoxNjQ0ODQzNzQzLCJpc3MiOiJBQ09KSk42V1VQNE9ERDc1WEVCS0tUQ0NVSkpDWTVaS1E1NlhWS1lLNEJFSldHVkFPT1FIWk1DVyIsInN1YiI6Ik1CQ0ZPUE02SlcyQVBKTFhKRDNaNU80Q043Q1BZSjJCNEZUS0xKVVI1WVI1TUlUSVU3SEQzV0Q1Iiwid2FzY2FwIjp7Im5hbWUiOiJFY2hvIiwiaGFzaCI6IjRDRUM2NzNBN0RDQ0VBNkE0MTY1QkIxOTU4MzJDNzkzNjQ3MUNGN0FCNDUwMUY4MzdGOEQ2NzlGNDQwMEJDOTciLCJ0YWdzIjpbXSwiY2FwcyI6WyJ3YXNtY2xvdWQ6aHR0cHNlcnZlciJdLCJyZXYiOjQsInZlciI6IjAuMy40IiwicHJvdiI6ZmFsc2V9fQ.ZWyD6VQqzaYM1beD2x9Fdw4o_Bavy3ZG703Eg4cjhyJwUKLDUiVPVhqHFE6IXdV4cW6j93YbMT6VGq5iBDWmAg" + t.Run("ParseWithClaims", func(t *testing.T) { + _, err := jwt.ParseWithClaims(validJWT, &jwt.RegisteredClaims{}, KeyPairFromIssuer()) + if err != nil { + t.Error(err) + } + }) + + t.Run("ComponentClaims", func(t *testing.T) { + token, err := jwt.ParseWithClaims(validJWT, &WasCap{}, KeyPairFromIssuer()) + if err != nil { + t.Fatal(err) + } + + var componentClaims ComponentClaims + wasCap := token.Claims.(*WasCap) + err = wasCap.ParseCapability(&componentClaims) + if err != nil { + t.Error(err) + } + }) +} + +func TestContext(t *testing.T) { + raw := `{"application":{"policy":"","name":"appname"},"entity_jwt":"eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJxdmVOakZjcW51dWhQaVJUMkU1YWJXIiwiaWF0IjoxNzIxODM0ODg5LCJpc3MiOiJBQk9HQjRXNURPWDNVTzNSVldXUUdZU01WWEhSUFFZWFZaUDVVNFZGTUpEQ1lDV0FSN1M1Q1lNTyIsInN1YiI6Ik1DNUNDNFVENUxQRFo0QzdaTkFFQTRPWlEzQkVGTFNWUTc0MlczVEVUM09OS1M0RFJCVk5NNUlDIiwid2FzY2FwIjp7Im5hbWUiOiJodHRwLWhlbGxvLXdvcmxkIiwiaGFzaCI6IkNFOTAxOTJDOTlDMEIyQzYwOEIyRTJDQjYxOUE5MjUxRkI2ODE4NTZDMTU2ODFCMUJDRDYyRUVEQTJENTEyOEUiLCJ0YWdzIjpbIndhc21jbG91ZC5jb20vZXhwZXJpbWVudGFsIl0sInJldiI6MCwidmVyIjoiMC4xLjAiLCJwcm92IjpmYWxzZX0sIndhc2NhcF9yZXZpc2lvbiI6M30.8awbkvrBnRKLpz88s7GXYCW0onpKf_nNfsj7pXhCyvq8pm4y2IotrIPCdBvWqDvDouX4VAM6DQQUHuI-VdKYAA","host_jwt":"eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJuTGdta2Zud2p2Nkw1R28xSlNUdU0zIiwiaWF0IjoxNzIyMDE5OTk1LCJpc3MiOiJBQzNGU0IzT0VSQ1IzVU00WVNWUjJUQURFVlFWUTNITVpQQUtHS082QkNRSTRSNEFITFY2SVhSMiIsInN1YiI6Ik5ETlBUM0QzWVNUQzVKR0g2QVBKUDZBTVZYUVk2QklETVVXWkdTU1FXMjZWSjNINFBDRjJTU0ZSIiwid2FzY2FwIjp7Im5hbWUiOiJkZWxpY2F0ZS1icmVlemUtOTc4NSIsImxhYmVscyI6eyJzZWxmX3NpZ25lZCI6InRydWUifX0sIndhc2NhcF9yZXZpc2lvbiI6M30.5LM_GOpo-6qg0kDrIP_jswI_ZQfOILzHT-FHixvUeAf-1isamLg81S-rb84w6topfvevI6quyV3b-uHZt6q9BQ"}` + ctx := &Context{} + err := json.Unmarshal([]byte(raw), ctx) + if err != nil { + t.Fatal(err) + } + if err := ctx.IsValid(); err != nil { + t.Fatal(err) + } +} diff --git a/x/wasmbus/secrets/ed25519.go b/x/wasmbus/secrets/ed25519.go new file mode 100644 index 00000000..f4cda4c6 --- /dev/null +++ b/x/wasmbus/secrets/ed25519.go @@ -0,0 +1,64 @@ +package secrets + +import ( + "errors" + "fmt" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/nats-io/nkeys" +) + +var ErrEd25519Verification = errors.New("ed25519: verification error") + +// SigningMethodNats implements nkeys based ed25519 signing and verification. +type SigningMethodNats struct{} + +// Specific instance for Ed25519 - nkeys edition +var ( + SigningMethodEd25519 *SigningMethodNats +) + +func init() { + SigningMethodEd25519 = &SigningMethodNats{} + jwt.RegisterSigningMethod(SigningMethodEd25519.Alg(), func() jwt.SigningMethod { + return SigningMethodEd25519 + }) +} + +func (m *SigningMethodNats) Alg() string { + return "Ed25519" +} + +// Verify implements token verification for the SigningMethod. +func (m *SigningMethodNats) Verify(signingString string, sig []byte, key interface{}) error { + var ed25519Key nkeys.KeyPair + var ok bool + + if ed25519Key, ok = key.(nkeys.KeyPair); !ok { + return fmt.Errorf("%w: Ed25519 sign expects nkeys.KeyPair", jwt.ErrInvalidKeyType) + } + + return ed25519Key.Verify([]byte(signingString), sig) +} + +// Sign implements token signing for the SigningMethod. +func (m *SigningMethodNats) Sign(signingString string, key interface{}) ([]byte, error) { + var ed25519Key nkeys.KeyPair + var ok bool + + if ed25519Key, ok = key.(nkeys.KeyPair); !ok { + return nil, fmt.Errorf("%w: Ed25519 sign expects nkeys.KeyPair", jwt.ErrInvalidKeyType) + } + + return ed25519Key.Sign([]byte(signingString)) +} + +func KeyPairFromIssuer() func(token *jwt.Token) (interface{}, error) { + return func(token *jwt.Token) (interface{}, error) { + iss, err := token.Claims.GetIssuer() + if err != nil { + return nil, err + } + return nkeys.FromPublicKey(iss) + } +} diff --git a/x/wasmbus/secrets/ed25519_test.go b/x/wasmbus/secrets/ed25519_test.go new file mode 100644 index 00000000..7849eef2 --- /dev/null +++ b/x/wasmbus/secrets/ed25519_test.go @@ -0,0 +1,53 @@ +package secrets + +import ( + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/nats-io/nkeys" +) + +func TestEd25519(t *testing.T) { + t.Run("SignAndVerify", func(t *testing.T) { + kp, err := nkeys.CreateAccount() + if err != nil { + t.Fatal(err) + } + pubKey, err := kp.PublicKey() + if err != nil { + t.Fatal(err) + } + + claims := jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: pubKey, + Subject: "somebody", + ID: "1", + Audience: []string{"somebody_else"}, + } + + token := jwt.NewWithClaims(SigningMethodEd25519, claims) + signedJWT, err := token.SignedString(kp) + if err != nil { + t.Fatal(err) + } + + // try opening the signed jwt + _, err = jwt.ParseWithClaims(signedJWT, &jwt.RegisteredClaims{}, KeyPairFromIssuer()) + if err != nil { + t.Error(err) + } + }) + + t.Run("Verify", func(t *testing.T) { + // jwt from wasmcloud host codebase + validJWT := "eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJTakI1Zm05NzRTanU5V01nVFVjaHNiIiwiaWF0IjoxNjQ0ODQzNzQzLCJpc3MiOiJBQ09KSk42V1VQNE9ERDc1WEVCS0tUQ0NVSkpDWTVaS1E1NlhWS1lLNEJFSldHVkFPT1FIWk1DVyIsInN1YiI6Ik1CQ0ZPUE02SlcyQVBKTFhKRDNaNU80Q043Q1BZSjJCNEZUS0xKVVI1WVI1TUlUSVU3SEQzV0Q1Iiwid2FzY2FwIjp7Im5hbWUiOiJFY2hvIiwiaGFzaCI6IjRDRUM2NzNBN0RDQ0VBNkE0MTY1QkIxOTU4MzJDNzkzNjQ3MUNGN0FCNDUwMUY4MzdGOEQ2NzlGNDQwMEJDOTciLCJ0YWdzIjpbXSwiY2FwcyI6WyJ3YXNtY2xvdWQ6aHR0cHNlcnZlciJdLCJyZXYiOjQsInZlciI6IjAuMy40IiwicHJvdiI6ZmFsc2V9fQ.ZWyD6VQqzaYM1beD2x9Fdw4o_Bavy3ZG703Eg4cjhyJwUKLDUiVPVhqHFE6IXdV4cW6j93YbMT6VGq5iBDWmAg" + _, err := jwt.ParseWithClaims(validJWT, &jwt.RegisteredClaims{}, KeyPairFromIssuer()) + if err != nil { + t.Error(err) + } + }) +} diff --git a/x/wasmbus/secrets/server.go b/x/wasmbus/secrets/server.go new file mode 100644 index 00000000..ae8e0d88 --- /dev/null +++ b/x/wasmbus/secrets/server.go @@ -0,0 +1,121 @@ +package secrets + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/nats-io/nkeys" + "go.wasmcloud.dev/x/wasmbus" +) + +type Server struct { + *wasmbus.Server + Name string + api APIv1alpha1 + key KeyPair + pubKey string +} + +type ServerOption func(*Server) error + +func EphemeralKey() (KeyPair, error) { + return nkeys.CreateCurveKeys() +} + +func KeyPairFromSeed(seed []byte) (KeyPair, error) { + return nkeys.FromSeed(seed) +} + +type KeyPair = nkeys.KeyPair + +func NewServer(bus wasmbus.Bus, name string, kp KeyPair, api APIv1alpha1) *Server { + server := &Server{ + Server: wasmbus.NewServer(bus), + Name: name, + api: api, + key: kp, + } + + return server +} + +type secretContextKey string + +const hostContextKey secretContextKey = "secret" + +func (s *Server) decodeCiphered(ctx context.Context, req *GetRequest, msg *wasmbus.Message) (context.Context, error) { + hostPubKey := msg.Header.Get(WasmCloudHostXkey) + if hostPubKey == "" { + return ctx, fmt.Errorf("%w: missing host public key", ErrInvalidHeaders) + } + + decrypted, err := s.key.Open(msg.Data, hostPubKey) + if err != nil { + return ctx, fmt.Errorf("%w: %s", ErrDecryption, err) + } + + if err := json.Unmarshal(decrypted, req); err != nil { + return ctx, err + } + + ctx = context.WithValue(ctx, hostContextKey, hostPubKey) + return ctx, nil +} + +func (s *Server) encodeCiphered(ctx context.Context, replyTo string, resp *GetResponse) (*wasmbus.Message, error) { + hostPubKey := ctx.Value(hostContextKey).(string) + + responseKey, err := nkeys.CreateCurveKeys() + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrEncryption, err) + } + + ephemeralPubKey, err := responseKey.PublicKey() + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrEncryption, err) + } + + msg, err := wasmbus.Encode(replyTo, resp) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrEncryption, err) + } + + msg.Data, err = responseKey.Seal(msg.Data, hostPubKey) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrEncryption, err) + } + + msg.Header.Add(WasmCloudResponseXkey, ephemeralPubKey) + + return msg, nil +} + +func (s *Server) serveXkey(ctx context.Context, msg *wasmbus.Message) error { + resp := wasmbus.NewMessage(msg.Reply) + resp.Data = []byte(s.pubKey) + return s.Publish(resp) +} + +func (s *Server) Serve() error { + var err error + s.pubKey, err = s.key.PublicKey() + if err != nil { + return err + } + + if err := s.RegisterHandler(s.subject("server_xkey"), wasmbus.ServerHandlerFunc(s.serveXkey)); err != nil { + return err + } + + get := wasmbus.NewRequestHandler(GetRequest{}, GetResponse{}, s.api.Get) + get.Decode = s.decodeCiphered + get.Encode = s.encodeCiphered + return s.RegisterHandler(s.subject("get"), get) +} + +func (s *Server) subject(ids ...string) string { + parts := append([]string{wasmbus.PrefixSecrets, PrefixVersion, s.Name}, ids...) + return strings.Join(parts, ".") +} diff --git a/x/wasmbus/secrets/server_test.go b/x/wasmbus/secrets/server_test.go new file mode 100644 index 00000000..1dc45397 --- /dev/null +++ b/x/wasmbus/secrets/server_test.go @@ -0,0 +1,364 @@ +package secrets + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/nats-io/nats.go" + "go.wasmcloud.dev/x/wasmbus" + "go.wasmcloud.dev/x/wasmbus/wasmbustest" +) + +func keyPairForTest(t *testing.T) KeyPair { + t.Helper() + + kp, err := EphemeralKey() + if err != nil { + t.Fatal(err) + } + + return kp +} + +type apiMock struct { + APIMock + t *testing.T +} + +func TestServerXKey(t *testing.T) { + defer wasmbustest.MustStartNats(t)() + + kp := keyPairForTest(t) + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatalf("failed to connect to nats: %v", err) + } + + bus := wasmbus.NewNatsBus(nc) + mock := &apiMock{} + s := NewServer(bus, "test", kp, mock) + if err := s.Serve(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + + serverPubKey, err := kp.PublicKey() + if err != nil { + t.Fatal(err) + } + + req := wasmbus.NewMessage(fmt.Sprintf("%s.%s.%s.server_xkey", wasmbus.PrefixSecrets, PrefixVersion, "test")) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + rawResp, err := bus.Request(ctx, req) + if err != nil { + t.Fatal(err) + } + if want, got := serverPubKey, string(rawResp.Data); want != got { + t.Errorf("want %v, got %v", want, got) + } +} + +func TestGet(t *testing.T) { + defer wasmbustest.MustStartNats(t)() + + kp := keyPairForTest(t) + + nc, err := nats.Connect(nats.DefaultURL) + if err != nil { + t.Fatalf("failed to connect to nats: %v", err) + } + + bus := wasmbus.NewNatsBus(nc) + mock := &apiMock{} + s := NewServer(bus, "test", kp, mock) + if err := s.Serve(); err != nil { + t.Fatalf("failed to start server: %v", err) + } + + // log errors + go func() { + for c := range s.ErrorStream() { + t.Log(c) + } + }() + + serverPubKey, err := kp.PublicKey() + if err != nil { + t.Fatal(err) + } + + reqCtx := Context{ + Application: &ApplicationContext{ + Name: "appname", + }, + EntityJwt: "eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJxdmVOakZjcW51dWhQaVJUMkU1YWJXIiwiaWF0IjoxNzIxODM0ODg5LCJpc3MiOiJBQk9HQjRXNURPWDNVTzNSVldXUUdZU01WWEhSUFFZWFZaUDVVNFZGTUpEQ1lDV0FSN1M1Q1lNTyIsInN1YiI6Ik1DNUNDNFVENUxQRFo0QzdaTkFFQTRPWlEzQkVGTFNWUTc0MlczVEVUM09OS1M0RFJCVk5NNUlDIiwid2FzY2FwIjp7Im5hbWUiOiJodHRwLWhlbGxvLXdvcmxkIiwiaGFzaCI6IkNFOTAxOTJDOTlDMEIyQzYwOEIyRTJDQjYxOUE5MjUxRkI2ODE4NTZDMTU2ODFCMUJDRDYyRUVEQTJENTEyOEUiLCJ0YWdzIjpbIndhc21jbG91ZC5jb20vZXhwZXJpbWVudGFsIl0sInJldiI6MCwidmVyIjoiMC4xLjAiLCJwcm92IjpmYWxzZX0sIndhc2NhcF9yZXZpc2lvbiI6M30.8awbkvrBnRKLpz88s7GXYCW0onpKf_nNfsj7pXhCyvq8pm4y2IotrIPCdBvWqDvDouX4VAM6DQQUHuI-VdKYAA", + HostJwt: "eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJuTGdta2Zud2p2Nkw1R28xSlNUdU0zIiwiaWF0IjoxNzIyMDE5OTk1LCJpc3MiOiJBQzNGU0IzT0VSQ1IzVU00WVNWUjJUQURFVlFWUTNITVpQQUtHS082QkNRSTRSNEFITFY2SVhSMiIsInN1YiI6Ik5ETlBUM0QzWVNUQzVKR0g2QVBKUDZBTVZYUVk2QklETVVXWkdTU1FXMjZWSjNINFBDRjJTU0ZSIiwid2FzY2FwIjp7Im5hbWUiOiJkZWxpY2F0ZS1icmVlemUtOTc4NSIsImxhYmVscyI6eyJzZWxmX3NpZ25lZCI6InRydWUifX0sIndhc2NhcF9yZXZpc2lvbiI6M30.5LM_GOpo-6qg0kDrIP_jswI_ZQfOILzHT-FHixvUeAf-1isamLg81S-rb84w6topfvevI6quyV3b-uHZt6q9BQ", + } + + tt := []struct { + name string + req *GetRequest + getFunc func(context.Context, *GetRequest) (*GetResponse, error) + validate func(*testing.T, *GetResponse) + }{ + { + name: "get string", + req: &GetRequest{ + Key: "key", + }, + getFunc: func(ctx context.Context, r *GetRequest) (*GetResponse, error) { + if want, got := "key", r.Key; want != got { + mock.t.Errorf("want %v, got %v", want, got) + } + return &GetResponse{ + Secret: &SecretValue{ + StringSecret: "hunter2", + }, + }, nil + }, + validate: func(t *testing.T, resp *GetResponse) { + if want, got := "hunter2", resp.Secret.StringSecret; want != got { + t.Errorf("want %v, got %v", want, got) + } + }, + }, + { + name: "get binary", + req: &GetRequest{ + Key: "keybin", + }, + getFunc: func(ctx context.Context, r *GetRequest) (*GetResponse, error) { + if want, got := "keybin", r.Key; want != got { + mock.t.Errorf("want %v, got %v", want, got) + } + return &GetResponse{ + Secret: &SecretValue{ + BinarySecret: BinarySecret([]byte("hunter2")), + }, + }, nil + }, + validate: func(t *testing.T, resp *GetResponse) { + if want, got := "hunter2", string(resp.Secret.BinarySecret); want != got { + t.Errorf("want %v, got %v", want, got) + } + }, + }, + { + name: "validate context", + req: &GetRequest{ + Key: "test", + Context: reqCtx, + }, + getFunc: func(ctx context.Context, r *GetRequest) (*GetResponse, error) { + if err := r.Context.IsValid(); err != nil { + mock.t.Errorf("context validation failed: %v", err) + } + return &GetResponse{}, nil + }, + validate: func(t *testing.T, resp *GetResponse) { + }, + }, + { + name: "internal error", + req: &GetRequest{ + Key: "test", + }, + getFunc: func(ctx context.Context, r *GetRequest) (*GetResponse, error) { + return &GetResponse{ + Error: ErrOther, + }, nil + }, + validate: func(t *testing.T, resp *GetResponse) { + if want, got := "Other", resp.Error.Tip; want != got { + t.Errorf("want %v, got %v", want, got) + } + }, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + mock.t = t + mock.GetFunc = tc.getFunc + hostKey := keyPairForTest(t) + hostPubKey, err := hostKey.PublicKey() + if err != nil { + t.Fatal(err) + } + + rawSreq, err := wasmbus.EncodeMimetype(tc.req, "application/json") + if err != nil { + t.Fatal(err) + } + + req := wasmbus.NewMessage(fmt.Sprintf("%s.%s.%s.get", wasmbus.PrefixSecrets, PrefixVersion, "test")) + req.Header.Add(WasmCloudHostXkey, hostPubKey) + req.Data, err = hostKey.Seal(rawSreq, serverPubKey) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + rawResp, err := bus.Request(ctx, req) + if err != nil { + t.Fatal(err) + } + decrypted, err := hostKey.Open(rawResp.Data, rawResp.Header.Get(WasmCloudResponseXkey)) + if err != nil { + t.Fatal(err) + } + + resp := &GetResponse{} + if err := json.Unmarshal(decrypted, resp); err != nil { + t.Fatal(err) + } + + tc.validate(t, resp) + }) + } + + if err := s.Drain(); err != nil { + t.Fatalf("failed to drain server: %v", err) + } +} + +/* + kp := keyPairForTest(t) + hostPubKey, err := kp.PublicKey() + if err != nil { + t.Fatal(err) + } + + reqCtx := Context{ + Application: &ApplicationContext{ + Name: "appname", + }, + EntityJwt: "eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJxdmVOakZjcW51dWhQaVJUMkU1YWJXIiwiaWF0IjoxNzIxODM0ODg5LCJpc3MiOiJBQk9HQjRXNURPWDNVTzNSVldXUUdZU01WWEhSUFFZWFZaUDVVNFZGTUpEQ1lDV0FSN1M1Q1lNTyIsInN1YiI6Ik1DNUNDNFVENUxQRFo0QzdaTkFFQTRPWlEzQkVGTFNWUTc0MlczVEVUM09OS1M0RFJCVk5NNUlDIiwid2FzY2FwIjp7Im5hbWUiOiJodHRwLWhlbGxvLXdvcmxkIiwiaGFzaCI6IkNFOTAxOTJDOTlDMEIyQzYwOEIyRTJDQjYxOUE5MjUxRkI2ODE4NTZDMTU2ODFCMUJDRDYyRUVEQTJENTEyOEUiLCJ0YWdzIjpbIndhc21jbG91ZC5jb20vZXhwZXJpbWVudGFsIl0sInJldiI6MCwidmVyIjoiMC4xLjAiLCJwcm92IjpmYWxzZX0sIndhc2NhcF9yZXZpc2lvbiI6M30.8awbkvrBnRKLpz88s7GXYCW0onpKf_nNfsj7pXhCyvq8pm4y2IotrIPCdBvWqDvDouX4VAM6DQQUHuI-VdKYAA", + HostJwt: "eyJ0eXAiOiJqd3QiLCJhbGciOiJFZDI1NTE5In0.eyJqdGkiOiJuTGdta2Zud2p2Nkw1R28xSlNUdU0zIiwiaWF0IjoxNzIyMDE5OTk1LCJpc3MiOiJBQzNGU0IzT0VSQ1IzVU00WVNWUjJUQURFVlFWUTNITVpQQUtHS082QkNRSTRSNEFITFY2SVhSMiIsInN1YiI6Ik5ETlBUM0QzWVNUQzVKR0g2QVBKUDZBTVZYUVk2QklETVVXWkdTU1FXMjZWSjNINFBDRjJTU0ZSIiwid2FzY2FwIjp7Im5hbWUiOiJkZWxpY2F0ZS1icmVlemUtOTc4NSIsImxhYmVscyI6eyJzZWxmX3NpZ25lZCI6InRydWUifX0sIndhc2NhcF9yZXZpc2lvbiI6M30.5LM_GOpo-6qg0kDrIP_jswI_ZQfOILzHT-FHixvUeAf-1isamLg81S-rb84w6topfvevI6quyV3b-uHZt6q9BQ", + } + + tests := map[string]struct { + plainText bool + req Request + protocolError bool + hostKey string + getFunc func(ctx context.Context, r *Request) (*SecretValue, error) + checkResponse func(*testing.T, Response) + }{ + "blank": { + plainText: true, + protocolError: true, + }, + "happyPath": { + req: Request{ + Key: "secret", + Context: reqCtx, + }, + }, + "upstreamError": { + req: Request{ + Key: "secret", + Context: reqCtx, + }, + protocolError: true, + getFunc: func(context.Context, *Request) (*SecretValue, error) { + return nil, ErrUpstream.With("boom") + }, + checkResponse: func(t *testing.T, resp Response) { + if want, got := ErrUpstream.Error(), resp.Error.Error(); want != got { + t.Errorf("want %v, got %v", want, got) + } + }, + }, + "badSecret": { + req: Request{ + Key: "secret", + Context: reqCtx, + }, + hostKey: "badkey", + protocolError: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.getFunc != nil { + handler.getFunc = test.getFunc + } else { + handler.getFunc = basicGetFunc + } + rawData, err := json.Marshal(&test.req) + if err != nil { + t.Fatal(err) + } + + rawReq := nats.NewMsg(server.subjectMapper.SecretsSubject() + ".get") + + if !test.plainText { + sealedData, err := kp.Seal(rawData, serverPubKey) + if err != nil { + t.Fatal(err) + } + + rawReq.Data = sealedData + hostKey := hostPubKey + if test.hostKey != "" { + hostKey = test.hostKey + } + rawReq.Header.Add(WasmCloudHostXkey, hostKey) + } + + rawReply, err := nc.RequestMsg(rawReq, time.Second) + if err != nil { + t.Fatal(err) + } + + var resp Response + + // the presence of the response header indicates if this is an encrypted response or not + // plain responses are protocol errors + responseKey := rawReply.Header.Get(WasmCloudResponseXkey) + if test.protocolError { + if responseKey != "" { + t.Error("saw encryption header on protocol error") + } + + if err := json.Unmarshal(rawReply.Data, &resp); err != nil { + t.Fatal(err) + } + + if resp.Error == nil { + t.Fatal("Expected an error but got none") + } + + if test.checkResponse != nil { + test.checkResponse(t, resp) + } + return + } + + if !test.protocolError && responseKey == "" { + t.Error("missing encryption header") + } + + rawResponse, err := kp.Open(rawReply.Data, responseKey) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(rawResponse, &resp); err != nil { + t.Fatal(err) + } + + if test.checkResponse != nil { + test.checkResponse(t, resp) + } else { + basicCheckResponse(t, resp) + } + }) + } +} +*/ diff --git a/x/wasmbus/server.go b/x/wasmbus/server.go index f0195ce4..50c02c23 100644 --- a/x/wasmbus/server.go +++ b/x/wasmbus/server.go @@ -18,9 +18,6 @@ type ServerError struct { // See `AnyServerHandler` for more information. type Server struct { Bus - // Lattice is an informative field containing the lattice name. - // It is NOT used when manipulating subjects. - Lattice string // ContextFunc is a function that returns a new context for each message. // Defaults to `context.Background`. ContextFunc func() context.Context @@ -31,10 +28,9 @@ type Server struct { } // NewServer returns a new server instance. -func NewServer(bus Bus, lattice string) *Server { +func NewServer(bus Bus) *Server { return &Server{ Bus: bus, - Lattice: lattice, ContextFunc: func() context.Context { return context.Background() }, errorStream: make(chan *ServerError), } @@ -121,15 +117,34 @@ func NewRequestHandler[T any, Y any](req T, resp Y, handler func(context.Context type RequestHandler[T any, Y any] struct { Request T Response Y + Decode func(context.Context, *T, *Message) (context.Context, error) + Encode func(context.Context, string, *Y) (*Message, error) PreRequest func(context.Context, *T, *Message) error PostRequest func(context.Context, *Y, *Message) error Handler func(context.Context, *T) (*Y, error) } +func (s *RequestHandler[T, Y]) decode(ctx context.Context, req *T, msg *Message) (context.Context, error) { + if s.Decode != nil { + return s.Decode(ctx, req, msg) + } + return ctx, Decode(msg, req) +} + +func (s *RequestHandler[T, Y]) encode(ctx context.Context, subject string, resp *Y) (*Message, error) { + if s.Encode != nil { + return s.Encode(ctx, subject, resp) + } + return Encode(subject, resp) +} + // HandleMessage implements the `AnyServerHandler` interface. func (s *RequestHandler[T, Y]) HandleMessage(ctx context.Context, msg *Message) error { + var err error + req := s.Request - err := Decode(msg, &req) + + ctx, err = s.decode(ctx, &req, msg) if err != nil { return fmt.Errorf("%w: %s", ErrDecode, err) } @@ -145,7 +160,7 @@ func (s *RequestHandler[T, Y]) HandleMessage(ctx context.Context, msg *Message) return fmt.Errorf("%w: %s", ErrOperation, err) } - rawResp, err := Encode(msg.Reply, resp) + rawResp, err := s.encode(ctx, msg.Reply, resp) if err != nil { return fmt.Errorf("%w: %s", ErrEncode, err) } @@ -163,3 +178,57 @@ func (s *RequestHandler[T, Y]) HandleMessage(ctx context.Context, msg *Message) return nil } + +// TypedHandler is a higher-level abstraction that can be used to register handlers for specific types. +// It uses a `TypeExtractor` function to extract the type from the message. +// Usefull when you want to handle different types of messages with different handlers based on a json field inside the message. +type TypedHandler struct { + extractor TypeExtractor + handlers map[string]AnyServerHandler + lock sync.Mutex +} + +// TypeExtractor is a function that extracts a type name from a message. +type TypeExtractor func(ctx context.Context, msg *Message) (string, error) + +// NewTypedHandler returns a new typed handler instance. +func NewTypedHandler(extractor TypeExtractor) *TypedHandler { + return &TypedHandler{extractor: extractor, handlers: make(map[string]AnyServerHandler)} +} + +// HandleMessage implements the `AnyServerHandler` interface. +func (h *TypedHandler) HandleMessage(ctx context.Context, msg *Message) error { + if h.extractor == nil { + return fmt.Errorf("%w: no type extractor", ErrOperation) + } + + kind, err := h.extractor(ctx, msg) + if err != nil { + return fmt.Errorf("%w: %s", ErrOperation, err) + } + + h.lock.Lock() + handler, ok := h.handlers[kind] + h.lock.Unlock() + + if !ok { + return fmt.Errorf("%w: no handler for type %s", ErrOperation, kind) + } + + return handler.HandleMessage(ctx, msg) +} + +// RegisterType registers a handler for a given type. +// The handler will be called when a message with the given type is received, after the type is extracted by the `TypeExtractor`. +func (h *TypedHandler) RegisterType(kind string, handler AnyServerHandler) error { + h.lock.Lock() + defer h.lock.Unlock() + + if _, ok := h.handlers[kind]; ok { + return fmt.Errorf("%w: handler for type %s already registered", ErrOperation, kind) + } + + h.handlers[kind] = handler + + return nil +} diff --git a/x/wasmbus/server_test.go b/x/wasmbus/server_test.go index 89112f01..2c3e79ff 100644 --- a/x/wasmbus/server_test.go +++ b/x/wasmbus/server_test.go @@ -18,7 +18,7 @@ func TestServerRegisterHandler(t *testing.T) { defer nc.Close() bus := NewNatsBus(nc) - server := NewServer(bus, "test") + server := NewServer(bus) err = server.RegisterHandler("test", ServerHandlerFunc(func(ctx context.Context, msg *Message) error { reply := NewMessage(msg.Reply) reply.Data = []byte("hello") @@ -52,7 +52,7 @@ func TestServerDrain(t *testing.T) { defer nc.Close() bus := NewNatsBus(nc) - server := NewServer(bus, "test") + server := NewServer(bus) checkpoint := make(chan bool) _ = server.RegisterHandler("slow", ServerHandlerFunc(func(ctx context.Context, msg *Message) error { close(checkpoint) @@ -87,7 +87,7 @@ func TestServerErrorStream(t *testing.T) { defer nc.Close() bus := NewNatsBus(nc) - server := NewServer(bus, "test") + server := NewServer(bus) bomb := errors.New("bomb") bombCh := make(chan error, 1) _ = server.RegisterHandler("bomb", ServerHandlerFunc(func(ctx context.Context, msg *Message) error { @@ -133,7 +133,7 @@ func TestRequestHandler(t *testing.T) { defer nc.Close() bus := NewNatsBus(nc) - server := NewServer(bus, "test") + server := NewServer(bus) handler := NewRequestHandler(testRequest{}, testResponse{}, func(ctx context.Context, req *testRequest) (*testResponse, error) { return &testResponse{ Hello: "world", diff --git a/x/wasmbus/wadm/server.go b/x/wasmbus/wadm/server.go index d56c29aa..00f29942 100644 --- a/x/wasmbus/wadm/server.go +++ b/x/wasmbus/wadm/server.go @@ -9,17 +9,15 @@ import ( type Server struct { *wasmbus.Server - Lattice string - api API - handlers map[string]wasmbus.AnyServerHandler + Lattice string + api API } func NewServer(bus wasmbus.Bus, lattice string, api API) *Server { return &Server{ - Server: wasmbus.NewServer(bus, lattice), - Lattice: lattice, - api: api, - handlers: make(map[string]wasmbus.AnyServerHandler), + Server: wasmbus.NewServer(bus), + Lattice: lattice, + api: api, } }