ddc8af80488c738c85cc9f9600155151db1bc1fa
[xonotic/xonstat.git] / xonstat / util / xs_interceptor / xs_interceptor.go
1 package main\r
2 \r
3 import "database/sql"\r
4 import "flag"\r
5 import "fmt"\r
6 import "html/template"\r
7 import "net/http"\r
8 import "os"\r
9 import "strings"\r
10 import "time"\r
11 import _ "github.com/mattn/go-sqlite3"\r
12 \r
13 // HTML templates\r
14 var templates = template.Must(template.ParseFiles("templates/landing.html"))\r
15 \r
16 func main() {\r
17         port := flag.Int("port", 6543, "Default port on which to accept requests")\r
18         url := flag.String("url", "http://localhost:6543/stats/submit", "URL to send POST requests against")\r
19         flag.Usage = usage\r
20         flag.Parse()\r
21 \r
22         if len(flag.Args()) < 1 {\r
23                 fmt.Println("Insufficient arguments: need a <command> to run. Exiting...")\r
24                 os.Exit(1)\r
25         }\r
26 \r
27         command := flag.Args()[0]\r
28         switch {\r
29         case command == "drop":\r
30                 drop_db()\r
31         case command == "create":\r
32                 create_db()\r
33         case command == "serve":\r
34                 serve(*port)\r
35         case command == "resubmit":\r
36                 resubmit(*url)\r
37         case command == "list":\r
38                 list()\r
39         default:\r
40                 fmt.Println("Unknown command! Exiting...")\r
41                 os.Exit(1)\r
42         }\r
43 }\r
44 \r
45 // override the default Usage function to show the different "commands"\r
46 // that are in the switch statement in main()\r
47 func usage() {\r
48         fmt.Fprintf(os.Stderr, "Usage of xs_interceptor:\n")\r
49         fmt.Fprintf(os.Stderr, "    xs_interceptor [options] <command>\n\n")\r
50         fmt.Fprintf(os.Stderr, "Where <command> is one of the following:\n")\r
51         fmt.Fprintf(os.Stderr, "    create   - create the requests db (sqlite3 db file)\n")\r
52         fmt.Fprintf(os.Stderr, "    drop     - remove the requests db\n")\r
53         fmt.Fprintf(os.Stderr, "    list     - lists the requests in the db\n")\r
54         fmt.Fprintf(os.Stderr, "    serve    - listens for stats requests, storing them if found\n")\r
55         fmt.Fprintf(os.Stderr, "    resubmit - resubmits the requests to another URL\n\n")\r
56         fmt.Fprintf(os.Stderr, "Where [options] is one or more of the following:\n")\r
57         fmt.Fprintf(os.Stderr, "    -port    - port number (int) to listen on for 'serve' command\n")\r
58         fmt.Fprintf(os.Stderr, "    -url     - url (string) to submit requests\n\n")\r
59 }\r
60 \r
61 // removes the requests database. it is just a file, so this is really easy.\r
62 func drop_db() {\r
63         err := os.Remove("middleman.db")\r
64 \r
65         if err != nil {\r
66                 fmt.Println("Error dropping the database middleman.db. Exiting...")\r
67                 os.Exit(1)\r
68         } else {\r
69                 fmt.Println("Dropped middleman.db successfully!")\r
70                 os.Exit(0)\r
71         }\r
72 }\r
73 \r
74 // creates the sqlite database. it's a hard-coded name because I don't see\r
75 // a need to change db names for this purpose.\r
76 func create_db() {\r
77         db, err := sql.Open("sqlite3", "./middleman.db")\r
78         defer db.Close()\r
79 \r
80         if err != nil {\r
81                 fmt.Println("Error creating the database middleman.db. Exiting...")\r
82                 fmt.Println(err)\r
83                 os.Exit(1)\r
84         } else {\r
85                 fmt.Println("Created middleman.db successfully!")\r
86         }\r
87 \r
88         _, err = db.Exec(`\r
89      CREATE TABLE requests (\r
90         request_id INTEGER PRIMARY KEY ASC, \r
91         blind_id_header TEXT, \r
92         ip_addr VARCHAR(32), \r
93         body TEXT, \r
94         bodylength int \r
95      );\r
96   `)\r
97 \r
98         if err != nil {\r
99                 fmt.Println("Error creating the table 'requests' in middleman.db. Exiting...")\r
100                 os.Exit(1)\r
101         } else {\r
102                 fmt.Println("Created table 'requests' successfully!")\r
103         }\r
104 }\r
105 \r
106 // an HTTP server that responds to two types of URLs: stats submissions (which it records)\r
107 // and everything else, which receive a down-page\r
108 func serve(port int) {\r
109         requests := 0\r
110 \r
111         // routing\r
112         http.HandleFunc("/", defaultHandler)\r
113         http.HandleFunc("/stats/submit", makeSubmitHandler(requests))\r
114         http.Handle("/m/", http.StripPrefix("/m/", http.FileServer(http.Dir("m"))))\r
115 \r
116         // serving\r
117         fmt.Printf("Serving on port %d...\n", port)\r
118         addr := fmt.Sprintf(":%d", port)\r
119   for true {\r
120     http.ListenAndServe(addr, nil)\r
121     time.Sleep(100*time.Millisecond)\r
122   }\r
123 }\r
124 \r
125 // intercepts all URLs, displays a landing page\r
126 func defaultHandler(w http.ResponseWriter, r *http.Request) {\r
127         err := templates.ExecuteTemplate(w, "landing.html", nil)\r
128         if err != nil {\r
129                 http.Error(w, err.Error(), http.StatusInternalServerError)\r
130         }\r
131 }\r
132 \r
133 // accepts stats requests at a given URL, stores them in requests\r
134 func makeSubmitHandler(requests int) http.HandlerFunc {\r
135         return func(w http.ResponseWriter, r *http.Request) {\r
136                 fmt.Println("in submission handler")\r
137 \r
138                 if r.Method != "POST" {\r
139                         http.Redirect(w, r, "/", http.StatusFound)\r
140                         return\r
141                 }\r
142 \r
143                 // check for blind ID header. If we don't have it, don't do anything\r
144                 var blind_id_header string\r
145                 _, ok := r.Header["X-D0-Blind-Id-Detached-Signature"]\r
146                 if ok {\r
147                         fmt.Println("Found a blind_id header. Extracting...")\r
148                         blind_id_header = r.Header["X-D0-Blind-Id-Detached-Signature"][0]\r
149                 } else {\r
150                         fmt.Println("No blind_id header found.")\r
151                         blind_id_header = ""\r
152                 }\r
153 \r
154                 remoteAddr := getRemoteAddr(r)\r
155 \r
156                 // and finally, read the body\r
157                 body := make([]byte, r.ContentLength)\r
158                 r.Body.Read(body)\r
159 \r
160                 db := getDBConn()\r
161                 defer db.Close()\r
162 \r
163                 _, err := db.Exec("INSERT INTO requests(blind_id_header, ip_addr, body, bodylength) VALUES(?, ?, ?, ?)", blind_id_header, remoteAddr, string(body), r.ContentLength)\r
164                 if err != nil {\r
165                         fmt.Println("Unable to insert request.")\r
166                         fmt.Println(err)\r
167                 }\r
168         }\r
169 }\r
170 \r
171 // gets the remote address out of http.Requests with X-Forwarded-For handling\r
172 func getRemoteAddr(r *http.Request) (remoteAddr string) {\r
173         val, ok := r.Header["X-Forwarded-For"]\r
174         if ok {\r
175                 remoteAddr = val[0]\r
176         } else {\r
177                 remoteAddr = r.RemoteAddr\r
178         }\r
179 \r
180         // sometimes a ":<port number>" comes attached, which\r
181         // needs removing\r
182         idx := strings.Index(remoteAddr, ":")\r
183         if idx != -1 {\r
184                 remoteAddr = remoteAddr[0:idx]\r
185         }\r
186 \r
187         return\r
188 }\r
189 \r
190 // resubmits stats request to a particular URL. this is intended to be used when\r
191 // you want to write back to the "real" XonStat\r
192 func resubmit(url string) {\r
193         db := getDBConn()\r
194         defer db.Close()\r
195 \r
196         rows, err := db.Query("SELECT request_id, ip_addr, blind_id_header, body, bodylength FROM requests ORDER BY request_id")\r
197         if err != nil {\r
198                 fmt.Println("Error reading rows from the database. Exiting...")\r
199                 os.Exit(1)\r
200         }\r
201         defer rows.Close()\r
202 \r
203         successfulRequests := make([]int, 0, 10)\r
204         for rows.Next() {\r
205                 // could use a struct here, but isntead just a bunch of vars\r
206                 var request_id int\r
207                 var blind_id_header string\r
208                 var ip_addr string\r
209                 var body string\r
210                 var bodylength int\r
211 \r
212                 if err := rows.Scan(&request_id, &ip_addr, &blind_id_header, &body, &bodylength); err != nil {\r
213                         fmt.Println("Error reading row for submission. Continuing...")\r
214                         continue\r
215                 }\r
216 \r
217                 req, _ := http.NewRequest("POST", url, strings.NewReader(body))\r
218                 //req.ContentLength = int64(bodylength)\r
219     //req.ContentLength = 0\r
220                 req.ContentLength = int64(len([]byte(body)))\r
221 \r
222                 header := map[string][]string{\r
223                         "X-D0-Blind-Id-Detached-Signature": {blind_id_header},\r
224                         "X-Forwarded-For":                  {ip_addr},\r
225                 }\r
226                 req.Header = header\r
227 \r
228                 res, err := http.DefaultClient.Do(req)\r
229                 if err != nil {\r
230                         fmt.Printf("Error submitting request #%d. Continuing...\n", request_id)\r
231                         fmt.Println(err)\r
232                         continue\r
233                 }\r
234                 defer res.Body.Close()\r
235 \r
236                 fmt.Printf("Request #%d: %s\n", request_id, res.Status)\r
237 \r
238                 if res.StatusCode < 500 {\r
239                         successfulRequests = append(successfulRequests, request_id)\r
240                 }\r
241         }\r
242 \r
243         // now that we're done resubmitting, let's clean up the successful requests\r
244         // by deleting them outright from the database\r
245         for _, val := range successfulRequests {\r
246                 deleteRequest(db, val)\r
247         }\r
248 }\r
249 \r
250 // lists all the requests and their information *in the XonStat log format* in\r
251 // order to 1) show what's in the db and 2) to be able to save/parse it (with\r
252 // xs_parse) for later use.\r
253 func list() {\r
254         db := getDBConn()\r
255         defer db.Close()\r
256 \r
257         rows, err := db.Query("SELECT request_id, ip_addr, blind_id_header, body FROM requests ORDER BY request_id")\r
258         if err != nil {\r
259                 fmt.Println("Error reading rows from the database. Exiting...")\r
260                 os.Exit(1)\r
261         }\r
262         defer rows.Close()\r
263 \r
264         for rows.Next() {\r
265                 var request_id int\r
266                 var blind_id_header string\r
267                 var ip_addr string\r
268                 var body string\r
269 \r
270                 if err := rows.Scan(&request_id, &ip_addr, &blind_id_header, &body); err != nil {\r
271                         fmt.Println("Error opening middleman.db. Did you create it?")\r
272                         continue\r
273                 }\r
274 \r
275                 fmt.Printf("Request: %d\n", request_id)\r
276                 fmt.Printf("IP Address: %s\n", ip_addr)\r
277                 fmt.Println("----- BEGIN REQUEST BODY -----")\r
278 \r
279                 if len(blind_id_header) > 0 {\r
280                         fmt.Printf("d0_blind_id: %s\n", blind_id_header)\r
281                 }\r
282 \r
283                 fmt.Print(body)\r
284                 fmt.Printf("\n----- END REQUEST BODY -----\n")\r
285         }\r
286 }\r
287 \r
288 // hard-coded sqlite database connection retriever to keep it simple\r
289 func getDBConn() *sql.DB {\r
290         conn, err := sql.Open("sqlite3", "./middleman.db")\r
291 \r
292         if err != nil {\r
293                 fmt.Println("Error opening middleman.db. Did you create it?")\r
294                 os.Exit(1)\r
295         }\r
296 \r
297         return conn\r
298 }\r
299 \r
300 // removes reqeusts from the database by request_id\r
301 func deleteRequest(db *sql.DB, request_id int) {\r
302         _, err := db.Exec("delete from requests where request_id = ?", request_id)\r
303         if err != nil {\r
304                 fmt.Printf("Could not remove request_id %d from the database. Reason: %v\n", request_id, err)\r
305         } else {\r
306                 fmt.Printf("Request #%d removed from the database.\n", request_id)\r
307         }\r
308 }\r