Replace CSRF token with tokens based on the user's ID.
[kraftakt.git] / app / user.go
1 package app
2
3 import (
4         "context"
5         "crypto/hmac"
6         "crypto/sha1"
7         "encoding/hex"
8         "fmt"
9         "net/http"
10         "sync"
11
12         "github.com/google/uuid"
13         legacy_context "golang.org/x/net/context"
14         "golang.org/x/oauth2"
15         "google.golang.org/appengine/datastore"
16         "google.golang.org/appengine/log"
17 )
18
19 type User struct {
20         key *datastore.Key
21
22         ID    string
23         Email string
24 }
25
26 type dbUser struct {
27         ID string
28 }
29
30 func NewUser(ctx context.Context, email string) (*User, error) {
31         var id string
32         err := datastore.RunInTransaction(ctx, func(ctx legacy_context.Context) error {
33                 key := datastore.NewKey(ctx, "User", email, 0, nil)
34
35                 var u dbUser
36                 err := datastore.Get(ctx, key, &u)
37                 if err != nil && err != datastore.ErrNoSuchEntity {
38                         return err
39                 }
40                 if err == nil {
41                         id = u.ID
42                         return nil
43                 }
44
45                 id = uuid.New().String()
46                 _, err = datastore.Put(ctx, key, &dbUser{
47                         ID: id,
48                 })
49                 return err
50         }, nil)
51         if err != nil {
52                 return nil, err
53         }
54
55         return &User{
56                 key:   datastore.NewKey(ctx, "User", email, 0, nil),
57                 ID:    id,
58                 Email: email,
59         }, nil
60 }
61
62 func UserByID(ctx context.Context, id string) (*User, error) {
63         q := datastore.NewQuery("User").Filter("ID=", id).KeysOnly()
64         keys, err := q.GetAll(ctx, nil)
65         if err != nil {
66                 return nil, fmt.Errorf("datastore.Query.GetAll(): %v", err)
67         }
68         if len(keys) != 1 {
69                 return nil, fmt.Errorf("len(keys) = %d, want 1", len(keys))
70         }
71
72         return &User{
73                 key:   keys[0],
74                 ID:    id,
75                 Email: keys[0].StringID(),
76         }, nil
77 }
78
79 func (u *User) Token(ctx context.Context, svc string) (*oauth2.Token, error) {
80         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
81
82         var tok oauth2.Token
83         if err := datastore.Get(ctx, key, &tok); err != nil {
84                 return nil, err
85         }
86
87         return &tok, nil
88 }
89
90 func (u *User) SetToken(ctx context.Context, svc string, tok *oauth2.Token) error {
91         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
92         _, err := datastore.Put(ctx, key, tok)
93         return err
94 }
95
96 func (u *User) DeleteToken(ctx context.Context, svc string) error {
97         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
98         return datastore.Delete(ctx, key)
99 }
100
101 func (u *User) OAuthClient(ctx context.Context, svc string, cfg *oauth2.Config) (*http.Client, error) {
102         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
103
104         var tok oauth2.Token
105         if err := datastore.Get(ctx, key, &tok); err != nil {
106                 return nil, err
107         }
108
109         src := cfg.TokenSource(ctx, &tok)
110         return oauth2.NewClient(ctx, &persistingTokenSource{
111                 ctx: ctx,
112                 t:   &tok,
113                 src: src,
114                 key: key,
115         }), nil
116 }
117
118 func (u *User) String() string {
119         return u.Email
120 }
121
122 func (u *User) Sign(payload string) string {
123         mac := hmac.New(sha1.New, []byte(u.ID))
124         mac.Write([]byte(payload))
125
126         return hex.EncodeToString(mac.Sum(nil))
127 }
128
129 type persistingTokenSource struct {
130         ctx context.Context
131         t   *oauth2.Token
132         src oauth2.TokenSource
133         key *datastore.Key
134
135         sync.Mutex
136 }
137
138 func (s *persistingTokenSource) Token() (*oauth2.Token, error) {
139         s.Lock()
140         defer s.Unlock()
141
142         tok, err := s.src.Token()
143         if err != nil {
144                 return nil, err
145         }
146
147         if s.t.AccessToken != tok.AccessToken ||
148                 s.t.TokenType != tok.TokenType ||
149                 s.t.RefreshToken != tok.RefreshToken ||
150                 !s.t.Expiry.Equal(tok.Expiry) {
151                 if _, err := datastore.Put(s.ctx, s.key, tok); err != nil {
152                         log.Errorf(s.ctx, "persisting OAuth token in datastore failed: %v", err)
153                 }
154         }
155
156         s.t = tok
157         return tok, nil
158 }