Monday, July 5, 2010

Abusing PrintfFormat in F#

A cool little feature of OCaml and F# is their type-safe printf, e.g. you can write:

printfn "Hi, my name is %s and I'm %d years old" "John" 29

but:

printfn "I'm %d years old" "John"

won't even compile because "John" is a string and not an int as expected.

Under the covers, the expression printfn "I'm %d years old" returns an int -> unit function, and that function's parameter types are what the arguments get checked against. But how does it know which function signature it has to generate?
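Partial application makes this visible: applying printfn (or sprintf) to just the format string gives back an ordinary curried function whose parameter types come from the format specifiers. A small sketch (sayAge and describe are just illustrative names):

```fsharp
// printfn applied only to its format string is an ordinary curried function
let sayAge : int -> unit = printfn "I'm %d years old"
sayAge 29   // prints: I'm 29 years old

// sprintf works the same way, but the final result is a string instead of unit
let describe : string -> int -> string = sprintf "Hi, my name is %s and I'm %d years old"
let greeting = describe "John" 29
```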

Looking at printfn's signature, you can see it's Printf.TextWriterFormat<'T> -> 'T. TextWriterFormat<'T> is in turn a type alias for PrintfFormat<'T, TextWriter, unit, unit>, also known as Format<'T, TextWriter, unit, unit>. This Format type is special: when the F# compiler finds a string literal where a PrintfFormat is expected, it parses the format specifiers and infers 'T from them (int -> unit for printfn "%d"), and that is what lets it statically type-check the *printf functions.
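The implicit conversion also works for plain let-bindings, as long as a type annotation tells the compiler which format type is wanted. A sketch using the Printf.TextWriterFormat and Printf.StringFormat aliases from FSharp.Core (the binding names are made up for the example):

```fsharp
// A string literal can stand in for a format value wherever the compiler expects one;
// the annotation tells it which instantiation we want, and the specifiers must match it.
let ageFormat : Printf.TextWriterFormat<int -> unit> = "I'm %d years old"
printfn ageFormat 29   // prints: I'm 29 years old

// sprintf takes a StringFormat, an alias for Format<'T, unit, string, string>
let nameFormat : Printf.StringFormat<string -> int -> string> = "%s is %d"
let text = sprintf nameFormat "John" 29

// The original format string is still available at runtime through .Value
let raw = nameFormat.Value
```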

Now, we can't add custom format specifiers, since they are built into the compiler, but we can process format values however we want.
For example, how about executing SQL queries with type-safe parameters? Your first try might be something like this:

let sql = sprintf "select * from user where name = '%s'" "John" 
let results = executeSQL sql
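For reference, this is the SQL text that sprintf would build if the name came from untrusted input (the input below is only an illustrative example):

```fsharp
// A hostile "name": the embedded quote closes the SQL string literal early
let userInput = "x' or '1'='1"
let sql = sprintf "select * from user where name = '%s'" userInput
// sql now has an extra condition that matches every row
```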

However, this would be vulnerable to SQL injection attacks because it's not using proper SQL parameters. So let's build our own PrintfFormat processing function! We want to be able to write:

let results = runQuery "select * from user where id = %d and name = %s" 100 "John" // silly query, I know

The signature we need for this to work is:

runQuery: PrintfFormat<'a, _, _, IDataReader> -> 'a

Note that the first type parameter of PrintfFormat must be the same type as runQuery's return type, otherwise it won't type-check the way we want. We can ignore the second and third type parameters for this example. The last type parameter, IDataReader, is the type we get back once all the required arguments have been applied.
The first type parameter, 'a, is really a function type, and that function is responsible for processing the format arguments. So runQuery has to return the function that will process them.
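One way to see what 'a turns into is to ask for its runtime type. printerType below is a hypothetical helper, not part of FSharp.Core; it fixes the second and third type parameters to unit and the result to string just to keep the sketch simple:

```fsharp
// Reports which function type 'a the compiler inferred from the format string
let printerType (fmt: PrintfFormat<'a, unit, unit, string>) : System.Type = typeof<'a>

// Two specifiers: 'a is int -> string -> 'Result
let t1 = printerType "select * from user where id = %d and name = %s"

// No specifiers: 'a collapses to the result type itself
let t2 = printerType "select * from user"
```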

Confused so far? Great, my job is done :-)
Just kidding! To illustrate, let's build a dummy implementation of runQuery that satisfies the test above:

open System.Data

// Dummy handler, hard-coded for an int -> string -> IDataReader format
let runQuery (query: PrintfFormat<'a, _, _, IDataReader>) : 'a =
    let proc (a: int) (b: string) : IDataReader = 
        printfn "%d %s" a b 
        null 
    unbox proc 
let results = runQuery "select * from user where id = %d and name = %s" 100 "John"

This will compile correctly, and when run it will print 100 John. results will be null, so of course it's not really a usable IDataReader, but at least everything type-checks.

The problem is that this primitive implementation is not really generic. It breaks as soon as we change any of the format parameters. For example:

let results = runQuery "select top 5 * from usuario where name = %s" "John"

It compiles just fine, but fails at runtime with an InvalidCastException. That's because in this last test a string -> IDataReader format handler was expected, but we provided proc, which is int -> string -> IDataReader.
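The failure can be reproduced without a database. runDummy below is a hypothetical variant of the dummy runQuery that returns a string instead of an IDataReader:

```fsharp
open System

// Same idea as the dummy runQuery: a hard-coded int -> string -> string handler
let runDummy (query: PrintfFormat<'a, unit, unit, string>) : 'a =
    let proc (a: int) (b: string) : string = sprintf "%d %s" a b
    unbox (box proc)

// The format matches the hard-coded handler, so this works
let ok = runDummy "select * from user where id = %d and name = %s" 100 "John"

// Here 'a is string -> string, so the unbox inside runDummy throws
let failed =
    try
        runDummy "select * from user where name = %s" "John" |> ignore
        false
    with :? InvalidCastException -> true
```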
In order to make this truly generic, we have to resort to reflection to dynamically build the format handler function. F# has some nice reflection helpers to handle F#-specific functions, records, etc.
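The two helpers doing the heavy lifting are FSharpType.GetFunctionElements, which splits a function type into its argument and result types, and FSharpValue.MakeFunction, which builds a function value of a given type from an obj -> obj implementation. A quick sketch of both on their own:

```fsharp
open Microsoft.FSharp.Reflection

// Split a curried function type into its first argument and the rest
let argType, resultType = FSharpType.GetFunctionElements typeof<int -> string -> bool>
// argType is int, resultType is string -> bool

// Build a typed function value at runtime from an untyped obj -> obj implementation
let doubler =
    FSharpValue.MakeFunction(typeof<int -> int>, fun (o: obj) -> box (unbox<int> o * 2))
    :?> (int -> int)

let result = doubler 21
```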

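open System
open System.Data
open Microsoft.FSharp.Reflection

let runQuery (query: PrintfFormat<'a, _, _, IDataReader>) : 'a =
    // Flatten a curried function type: int -> string -> IDataReader becomes [int; string; IDataReader]
    let rec getFlattenedFunctionElements (functionType: Type) =
        let domain, range = FSharpType.GetFunctionElements functionType
        if not (FSharpType.IsFunction range)
            then domain::[range]
            else domain::getFlattenedFunctionElements(range)
    let types = getFlattenedFunctionElements typeof<'a>
    // Receives one argument, then returns either the final result or a function for the next argument
    let rec proc (types: Type list) (a: obj) : obj =
        match types with
        | [x;_] ->
            // Last argument: nothing left to curry
            printfn "last! %A" a
            box null
        | x::y::z::xs ->
            printfn "%A" a
            // Wrap the rest of the chain as a typed y -> z function for the next argument
            let cont = proc (y::z::xs)
            let ft = FSharpType.MakeFunctionType(y,z)
            let cont = FSharpValue.MakeFunction(ft, cont)
            box cont
        | _ -> failwith "shouldn't happen"
    let handler = proc types
    unbox (FSharpValue.MakeFunction(typeof<'a>, handler))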

Like the previous version, this one just prints the format arguments, but it is fully generic: it handles any number of arguments (at least one) of any type.
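To see what proc walks over, here is the same getFlattenedFunctionElements run on its own. It also shows why a format without any specifiers needs separate handling: its 'a is not a function type at all, and GetFunctionElements rejects it.

```fsharp
open System
open Microsoft.FSharp.Reflection

// Same flattening as in runQuery: 'a -> 'b -> 'c becomes ['a; 'b; 'c]
let rec getFlattenedFunctionElements (functionType: Type) =
    let domain, range = FSharpType.GetFunctionElements functionType
    if not (FSharpType.IsFunction range)
        then domain::[range]
        else domain::getFlattenedFunctionElements range

let threeArgs = getFlattenedFunctionElements typeof<int -> string -> float -> bool>

// A non-function type (what a specifier-free format gives you) is rejected
let noArgsFails =
    try
        getFlattenedFunctionElements typeof<string> |> ignore
        false
    with :? ArgumentException -> true
```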
Now all we have to do is abstract away the argument processing (I'll rename the function to PrintfFormatProc):

let PrintfFormatProc (worker: string * obj list -> 'd)  (query: PrintfFormat<'a, _, _, 'd>) : 'a =
    // No specifiers: 'a is already the result type, so call the worker right away
    if not (FSharpType.IsFunction typeof<'a>) then
        unbox (worker (query.Value, []))
    else
        let rec getFlattenedFunctionElements (functionType: Type) =
            let domain, range = FSharpType.GetFunctionElements functionType
            if not (FSharpType.IsFunction range)
                then domain::[range]
                else domain::getFlattenedFunctionElements(range)
        let types = getFlattenedFunctionElements typeof<'a>
        // Accumulate the arguments (in reverse) until the last one arrives, then call the worker
        let rec proc (types: Type list) (values: obj list) (a: obj) : obj =
            let values = a::values
            match types with
            | [x;_] ->
                let result = worker (query.Value, List.rev values)
                box result
            | x::y::z::xs ->
                let cont = proc (y::z::xs) values
                let ft = FSharpType.MakeFunctionType(y,z)
                let cont = FSharpValue.MakeFunction(ft, cont)
                box cont
            | _ -> failwith "shouldn't happen"
        let handler = proc types []
        unbox (FSharpValue.MakeFunction(typeof<'a>, handler))

Note that I also added a special case to handle a format with no parameters. Now we write the function that actually processes the format arguments: it replaces %s, %d, etc. with the SQL parameters @p0, @p1, etc., and then runs the query:

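open System
open System.Data
open System.Data.SqlClient
open System.Text.RegularExpressions

let connectionString = "data source=.;Integrated Security=true;Initial Catalog=SomeDatabase"
let sqlProcessor (sql: string, values: obj list) : IDataReader =
    // Replace each format specifier with a numbered SQL parameter, in order.
    // Assumes simple two-character specifiers such as %s and %d.
    let stripFormatting s =
        let i = ref -1
        let eval (rxMatch: Match) =
            incr i
            sprintf "@p%d" !i
        Regex.Replace(s, "%.", eval)
    let sql = stripFormatting sql
    let conn = new SqlConnection(connectionString)
    conn.Open()
    let cmd = conn.CreateCommand()
    cmd.CommandText <- sql
    let createParam i (p: obj) =
        let param = cmd.CreateParameter()
        param.ParameterName <- sprintf "@p%d" i
        // ADO.NET expects DBNull.Value rather than null for a SQL NULL
        param.Value <- (match p with null -> box DBNull.Value | _ -> p)
        cmd.Parameters.Add param |> ignore
    values |> Seq.iteri createParam
    // CloseConnection closes the connection when the reader is disposed
    upcast cmd.ExecuteReader(CommandBehavior.CloseConnection)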

let runQuery a = PrintfFormatProc sqlProcessor a

let queryUser = runQuery "select top 5 * from user where id = %d and name = %s"

use results = queryUser 100 "John"
while results.Read() do
    printfn "%A" results.["id"] 
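The placeholder rewriting is the only string manipulation involved, so it can be checked without a SQL Server instance. This is stripFormatting pulled out of sqlProcessor, using .Value on the ref cell instead of incr and ! (newer F# versions discourage those). As in the original, any two-character %x sequence is treated as a parameter, so %% or specifiers with flags (like %5d) are not handled.

```fsharp
open System.Text.RegularExpressions

// Each %x specifier becomes @p0, @p1, ... in order of appearance
let stripFormatting (s: string) =
    let counter = ref (-1)
    let eval (rxMatch: Match) =
        counter.Value <- counter.Value + 1
        sprintf "@p%d" counter.Value
    Regex.Replace(s, "%.", MatchEvaluator(fun m -> eval m))

let rewritten = stripFormatting "select top 5 * from user where id = %d and name = %s"
```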

Pretty cool, huh? This trick can be used for other things too, since PrintfFormatProc is fully reusable. For example, I'm currently using PrintfFormat manipulation to define type-safe routing in a web application (part of my master's thesis; I'll blog about it soon).
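Because PrintfFormatProc only hands the worker the format string and the boxed arguments, any worker will do. As an illustration, here it is with a hypothetical capture worker that simply returns what it receives, which also makes the mechanics testable without a database. PrintfFormatProc is repeated (with the unused pattern variables replaced by _) so the block runs standalone:

```fsharp
open System
open Microsoft.FSharp.Reflection

let PrintfFormatProc (worker: string * obj list -> 'd) (query: PrintfFormat<'a, _, _, 'd>) : 'a =
    if not (FSharpType.IsFunction typeof<'a>) then
        // No specifiers: 'a is the result type itself
        unbox (worker (query.Value, []))
    else
        let rec getFlattenedFunctionElements (functionType: Type) =
            let domain, range = FSharpType.GetFunctionElements functionType
            if not (FSharpType.IsFunction range)
                then domain::[range]
                else domain::getFlattenedFunctionElements range
        let types = getFlattenedFunctionElements typeof<'a>
        // Collect arguments in reverse; call the worker once the last one arrives
        let rec proc (types: Type list) (values: obj list) (a: obj) : obj =
            let values = a::values
            match types with
            | [_; _] -> box (worker (query.Value, List.rev values))
            | _::y::z::xs ->
                let cont = proc (y::z::xs) values
                let ft = FSharpType.MakeFunctionType(y, z)
                box (FSharpValue.MakeFunction(ft, cont))
            | _ -> failwith "shouldn't happen"
        unbox (FSharpValue.MakeFunction(typeof<'a>, proc types []))

// Hypothetical worker: instead of running SQL, hand back the format and its arguments
let capture (format: string, args: obj list) = format, args

let withArgs = PrintfFormatProc capture "id = %d and name = %s" 100 "John"
let noArgs = PrintfFormatProc capture "select 1"
```

A worker for logging, URL building or similar would plug in the same way, receiving the raw format string plus the arguments in order.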

Full source code here
