Throughout this course, we have worked with pure functions, which have no side-effects, and have been encouraged to use them wherever possible.
jshell> Function<String, Integer> f = s -> s.length()
f ==> $Lambda$..
jshell> Function<Integer, Integer> g = x -> x * 2
g ==> $Lambda$..
jshell> f.apply("abc")
$.. ==> 3
jshell> g.apply(3)
$.. ==> 6
Pure functions can also be composed to create more complex pure functions:
jshell> g.compose(f).apply("abc")
$.. ==> 6
Let us consider logging using an external variable (or state).
jshell> String log = ""
log ==> ""
jshell> Function<String, Integer> f = s -> { log = log + "f"; return s.length(); }
f ==> $Lambda$..
jshell> Function<Integer, Integer> g = x -> { log = log + "g"; return x * 2; }
g ==> $Lambda$..
jshell> g.compose(f).apply("abc")
$.. ==> 6
jshell> log
log ==> "fg" // log is changed!
jshell> g.compose(f).apply("abc")
$.. ==> 6
jshell> log
log ==> "fgfg" // log is changed again!
One way of avoiding the side-effect is to encapsulate the value and the log into a pair.
jshell> Function<Pair<String, String>, Pair<Integer, String>> f = x ->
   ...> new Pair<Integer, String>(x.first().length(), x.second() + "f")
f ==> $Lambda$..
jshell> Function<Pair<Integer, String>, Pair<Integer, String>> g = x ->
   ...> new Pair<Integer, String>(x.first() * 2, x.second() + "g")
g ==> $Lambda$..
jshell> g.compose(f).apply(new Pair<String, String>("abc",""))
$.. ==> (6, fg)
jshell> g.compose(f).apply(new Pair<String, String>("abc",""))
$.. ==> (6, fg) // value is unchanged
However, this approach requires each pure function, such as f and g, to handle the accumulation of the log together with the transformation of its value.
We shall now attempt to separate these two tasks. Let us first generalize our situation by considering the two functions f and g below with arbitrary types A, B, and C:

f :: A -> B
g :: B -> C
As we have seen, we can include the state (of type S) by pairing it with each input and output:

f :: (A, S) -> (B, S)
g :: (B, S) -> (C, S)

which gives the composition g o f :: (A, S) -> (C, S).
Notice that the functions f and g above can be curried:

f :: A -> S -> (B, S)
g :: B -> S -> (C, S)

but we can no longer readily compose g o f, since the output type of f (S -> (B, S)) now differs from the input type of g (B).
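Let us encapsulate functions of the form S -> (T, S) in a context StateM<T, S>. We now have

f :: A -> StateM<B, S>
g :: B -> StateM<C, S>

so that, starting from some state :: StateM<A, S>, the two functions can be sequenced with flatMap:

state.flatMap(f).flatMap(g)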
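To see what flatMap has to do behind the scenes, the sketch below composes the curried versions of f and g by hand, using the logging example from earlier (A = String, B = C = Integer, S = String). The `Pair` record is only a stand-in for the provided `Pair` class (assumed to offer `first()`, `second()` and a `(value, state)` string form; being a record, it needs Java 16+), and `CurriedCompose`, `F`, `G` and `composeCurried` are hypothetical names used purely for illustration.

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() {
        return "(" + first + ", " + second + ")";
    }
}

// Hypothetical class holding the illustration
class CurriedCompose {
    // f :: A -> S -> (B, S), with A = String, B = Integer, S = String (the log)
    static final Function<String, Function<String, Pair<Integer, String>>> F =
            a -> log -> new Pair<>(a.length(), log + "f");

    // g :: B -> S -> (C, S), with B = C = Integer, S = String
    static final Function<Integer, Function<String, Pair<Integer, String>>> G =
            b -> log -> new Pair<>(b * 2, log + "g");

    // Hand-written "g o f" for curried state-passing functions:
    // run f's state function on s, then give its value to g and its new state to g's state function
    static <A, B, C, S> Function<A, Function<S, Pair<C, S>>> composeCurried(
            Function<A, Function<S, Pair<B, S>>> f,
            Function<B, Function<S, Pair<C, S>>> g) {
        return a -> s -> {
            Pair<B, S> p = f.apply(a).apply(s);
            return g.apply(p.first()).apply(p.second());
        };
    }
}
```

Here `CurriedCompose.composeCurried(CurriedCompose.F, CurriedCompose.G).apply("abc").apply("")` evaluates to `(6, fg)`. The inner lambda, which runs one state function and threads its value and state into the next, is exactly the plumbing that StateM and flatMap are meant to hide.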
Your task is to define StateM with the appropriate flatMap method, as well as other methods to update the state.
This task comprises a number of levels. You are required to complete ALL levels.
The following are the constraints imposed on this task. In general, you should keep to the constructs and programming discipline instilled throughout the module.
The Pair and AbstractStateM classes have been provided for you. You are NOT ALLOWED to modify these classes.
You are given the following AbstractStateM class:
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;

    AbstractStateM(Function<S, Pair<T, S>> f) {
        this.f = f;
    }

    AbstractStateM(T t) {
        this(s -> new Pair<T, S>(t, s));
    }

    Pair<T, S> accept(S s) {
        return this.f.apply(s);
    }
}
Write the concrete class StateM<T, S> that inherits from AbstractStateM with the following methods:
You may write your own constructors. However, DO NOT declare any other instance properties or constants in StateM. All other helper methods, if any, must be declared private.
jshell> StateM.<String, Integer>unit("init")
$.. ==> StateM
The initial state value of type S is passed into the pipeline via the terminating accept method. Subsequently, a Pair<T, S> object comprising the value of type T and the state of type S is returned at the end of the pipeline.
jshell> StateM.<String, Integer>unit("init").accept(0)
$.. ==> (init, 0)
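A minimal sketch of StateM covering only unit, accept and toString might look as follows. It assumes a record stand-in for the provided Pair class (Java 16+), reproduces AbstractStateM unchanged, and routes everything through a single private Function-based constructor; that is an arbitrary choice, and delegating to the provided AbstractStateM(T t) constructor would work equally well.

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() {
        return "(" + first + ", " + second + ")";
    }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;

    AbstractStateM(Function<S, Pair<T, S>> f) {
        this.f = f;
    }

    AbstractStateM(T t) {
        this(s -> new Pair<T, S>(t, s));
    }

    Pair<T, S> accept(S s) {
        return this.f.apply(s);
    }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    // Single private constructor: every StateM wraps a function S -> (T, S)
    private StateM(Function<S, Pair<T, S>> f) {
        super(f);
    }

    // unit pairs the given value with whatever state is passed in, leaving the state unchanged
    static <T, S> StateM<T, S> unit(T t) {
        return new StateM<T, S>(s -> new Pair<T, S>(t, s));
    }

    @Override
    public String toString() {
        return "StateM";
    }
}
```

Nothing is computed until accept supplies the initial state; unit merely records the value to pair with whatever state eventually arrives.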
To change the state within the pipeline, we need two more static factory methods to read and write states.
Write the get method, which takes no arguments and returns a StateM object whose value is the same as the state.
jshell> StateM.<Integer>get()
$.. ==> StateM
jshell> StateM.<Integer>get().accept(0)
$.. ==> (0, 0)
jshell> StateM.<Integer>get().accept(10)
$.. ==> (10, 10)
Next, write the put method, which takes in a new state and returns a StateM with the updated state but no value (that is, nothing).
You will first need to define a Nothing type (or class) according to the sample run below:
jshell> Nothing nothing = Nothing.nothing()
nothing ==> -
jshell> Function<Nothing, Integer> f = x -> 1
f ==> $Lambda$..
jshell> f.apply(Nothing.nothing())
$.. ==> 1
jshell> StateM.<Integer>put(10)
$.. ==> StateM
jshell> StateM.<Integer>put(10).accept(0)
$.. ==> (-, 10)

The last test case above first takes in 0 as the state, which is then updated to 10.
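get and put can be sketched in the same style, assuming Nothing is a singleton whose toString returns "-" as in the sample run. Pair is again a record stand-in for the provided class (Java 16+), and only the methods needed here are shown.

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() {
        return "(" + first + ", " + second + ")";
    }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;
    AbstractStateM(Function<S, Pair<T, S>> f) { this.f = f; }
    AbstractStateM(T t) { this(s -> new Pair<T, S>(t, s)); }
    Pair<T, S> accept(S s) { return this.f.apply(s); }
}

// Singleton placeholder for "no value", printed as "-"
final class Nothing {
    private static final Nothing NOTHING = new Nothing();

    private Nothing() { }

    static Nothing nothing() {
        return NOTHING;
    }

    @Override
    public String toString() {
        return "-";
    }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    private StateM(Function<S, Pair<T, S>> f) {
        super(f);
    }

    // get copies the incoming state into the value slot; the state itself is unchanged
    static <S> StateM<S, S> get() {
        return new StateM<S, S>(s -> new Pair<S, S>(s, s));
    }

    // put ignores the incoming state and replaces it with newState; the value is nothing
    static <S> StateM<Nothing, S> put(S newState) {
        return new StateM<Nothing, S>(s -> new Pair<Nothing, S>(Nothing.nothing(), newState));
    }

    @Override
    public String toString() {
        return "StateM";
    }
}
```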
Note that since unit, get and put are static factory methods, calling them one after another does not sequence them: each call simply ignores the StateM it is chained onto. For example,

jshell> StateM.<String, Integer>unit("init").get().put(10).unit(4).accept(0)
$.. ==> (4, 0)

is simply equivalent to

jshell> StateM.<Integer, Integer>unit(4).accept(0)
$.. ==> (4, 0)
Moreover, chaining in this way does not allow us to update the state to, say, "10 more than the previous state", because no call has access to the value or state produced by the one before it. This is where flatMap comes in.
Write the flatMap method to sequence StateM objects, as in the sample run below. For simplicity, you may ignore bounded wildcards.
jshell> StateM.<String, Integer>unit("init").
   ...> flatMap(x -> StateM.<Integer, Integer>unit(x.length()))
$.. ==> StateM
jshell> StateM.<String, Integer>unit("init").
   ...> flatMap(x -> StateM.<Integer, Integer>unit(x.length())).accept(0)
$.. ==> (4, 0)
jshell> Function<Integer, StateM<Integer, String>> id = x ->
   ...> StateM.<Integer, String>unit(x)
id ==> $Lambda$..
jshell> Function<Integer, StateM<Integer, String>> f = x ->
   ...> StateM.<Integer, String>unit(x + 1)
f ==> $Lambda$..
jshell> f.apply(1).accept("initState")
$.. ==> (2, initState)
jshell> id.apply(1).flatMap(f).accept("initState")
$.. ==> (2, initState)
jshell> f.apply(1).flatMap(id).accept("initState")
$.. ==> (2, initState)
jshell> StateM.<Integer, String>unit(1).
   ...> flatMap(f).
   ...> flatMap(f).
   ...> accept("initState")
$.. ==> (3, initState)
jshell> StateM.<Integer, String>unit(1).
   ...> flatMap(x -> f.apply(x).
   ...> flatMap(f)).accept("initState")
$.. ==> (3, initState)
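One possible flatMap, without bounded wildcards as the task allows, is sketched below. The returned StateM, when given a state, first runs this StateM on it, passes the resulting value to mapper, and then runs the StateM that mapper returns on the updated state. Pair is a record stand-in for the provided class (Java 16+); get and put are omitted for brevity.

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() {
        return "(" + first + ", " + second + ")";
    }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;
    AbstractStateM(Function<S, Pair<T, S>> f) { this.f = f; }
    AbstractStateM(T t) { this(s -> new Pair<T, S>(t, s)); }
    Pair<T, S> accept(S s) { return this.f.apply(s); }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    private StateM(Function<S, Pair<T, S>> f) {
        super(f);
    }

    static <T, S> StateM<T, S> unit(T t) {
        return new StateM<T, S>(s -> new Pair<T, S>(t, s));
    }

    // Sequence this StateM with the one produced by mapper, threading the state through both
    <R> StateM<R, S> flatMap(Function<T, StateM<R, S>> mapper) {
        return new StateM<R, S>(s -> {
            Pair<T, S> p = this.accept(s);                       // run this step on the incoming state
            return mapper.apply(p.first()).accept(p.second());   // run the next step on the new state
        });
    }

    @Override
    public String toString() {
        return "StateM";
    }
}
```

Because flatMap only builds a new function, a chain such as unit(1).flatMap(f).flatMap(f) does no work until accept("initState") is called, at which point it yields (3, initState).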
A more practical test case is given below.
jshell> StateM<Integer, Integer> bar(StateM<String, Integer> sm) {
   ...>     return sm.flatMap(x -> StateM.<Integer>get()
   ...>         .flatMap(y -> StateM.<Integer>put(y + 10))
   ...>         .flatMap(z -> StateM.<Integer, Integer>unit(x.length())));
   ...> }
|  modified method bar(StateM<String, Integer>)
jshell> bar(StateM.<String, Integer>unit("init")).accept(1)
$.. ==> (4, 11)
Admittedly, the bar method looks rather contrived. However, its monad comprehension is simply:
StateM<Integer, Integer> bar(StateM<String, Integer> sm) {
    do {
        x <- sm;           // get the string value from sm
        y <- get();        // get the state left by sm
        put(y + 10);       // increase the state by 10
        unit(x.length());  // transform the string value to its length
    }
}
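The comprehension can be turned into Java mechanically. One common translation, stated here as an assumption (the course notes may phrase their scheme differently), rewrites `x <- m; rest` as `m.flatMap(x -> rest)`, rewrites a bare statement `m; rest` as `m.flatMap(ignored -> rest)`, and leaves the final expression as it is. The sketch applies this to bar alongside the chained version from the sample run; `BarDemo` and `barTranslated` are hypothetical names, and Pair (a record stand-in, Java 16+), Nothing and StateM are restated so the block compiles on its own.

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() { return "(" + first + ", " + second + ")"; }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;
    AbstractStateM(Function<S, Pair<T, S>> f) { this.f = f; }
    AbstractStateM(T t) { this(s -> new Pair<T, S>(t, s)); }
    Pair<T, S> accept(S s) { return this.f.apply(s); }
}

// Singleton placeholder for "no value", printed as "-"
final class Nothing {
    private static final Nothing NOTHING = new Nothing();
    private Nothing() { }
    static Nothing nothing() { return NOTHING; }
    @Override
    public String toString() { return "-"; }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    private StateM(Function<S, Pair<T, S>> f) { super(f); }

    static <T, S> StateM<T, S> unit(T t) {
        return new StateM<T, S>(s -> new Pair<T, S>(t, s));
    }

    static <S> StateM<S, S> get() {
        return new StateM<S, S>(s -> new Pair<S, S>(s, s));
    }

    static <S> StateM<Nothing, S> put(S newState) {
        return new StateM<Nothing, S>(s -> new Pair<Nothing, S>(Nothing.nothing(), newState));
    }

    <R> StateM<R, S> flatMap(Function<T, StateM<R, S>> mapper) {
        return new StateM<R, S>(s -> {
            Pair<T, S> p = this.accept(s);
            return mapper.apply(p.first()).accept(p.second());
        });
    }

    @Override
    public String toString() { return "StateM"; }
}

// Hypothetical class holding both versions of bar
class BarDemo {
    // Chained version from the sample run
    static StateM<Integer, Integer> bar(StateM<String, Integer> sm) {
        return sm.flatMap(x -> StateM.<Integer>get()
                .flatMap(y -> StateM.<Integer>put(y + 10))
                .flatMap(z -> StateM.<Integer, Integer>unit(x.length())));
    }

    // Direct, fully nested translation of the comprehension, one line per step
    static StateM<Integer, Integer> barTranslated(StateM<String, Integer> sm) {
        return sm.flatMap(x ->                                  // x <- sm
                StateM.<Integer>get().flatMap(y ->              // y <- get()
                StateM.<Integer>put(y + 10).flatMap(ignored ->  // put(y + 10)
                StateM.<Integer, Integer>unit(x.length()))));   // unit(x.length())
    }
}
```

Both versions give (4, 11) for an initial state of 1. The fully nested form falls straight out of the translation, while the sample run regroups the inner get/put/unit steps into a chain; the two agree because flatMap is associative.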
Interestingly, here is how we can make use of a String state for logging:
jshell> Function<String, StateM<Integer, String>> f = s -> {
   ...>     return StateM.<String>get()
   ...>         .flatMap(x -> StateM.<String>put(x + "f"))
   ...>         .flatMap(y -> StateM.<Integer, String>unit(s.length()));
   ...> }
f ==> $Lambda$..
jshell> Function<Integer, StateM<Integer, String>> g = x -> {
   ...>     return StateM.<String>get()
   ...>         .flatMap(y -> StateM.<String>put(y + "g"))
   ...>         .flatMap(z -> StateM.<Integer, String>unit(2 * x));
   ...> }
g ==> $Lambda$..
jshell> StateM.<String, String>unit("abc").flatMap(f).flatMap(g).accept("")
$.. ==> (6, fg)
The following is an attempt to count the number of method calls to the fib method when finding the n-th term of the Fibonacci sequence. Clearly, it uses an external state with side-effects.
jshell> int count = 0
count ==> 0
jshell> int fib(int n) {
   ...>     count = count + 1;
   ...>     if (n <= 1) {
   ...>         return n;
   ...>     } else {
   ...>         return fib(n - 1) + fib(n - 2);
   ...>     }
   ...> }
|  modified method fib(int)
jshell> fib(5)
$.. ==> 5
jshell> count
count ==> 15
jshell> fib(5)
$.. ==> 5
jshell> count
count ==> 30
Rewrite the fib function in level4.jsh so as to capture the number of function activations using StateM<Integer, Integer> instead.
First, create a method inc() in level4.jsh to increment the integer state by 1.
jshell> inc()
$.. ==> StateM
jshell> inc().accept(0)
$.. ==> (-, 1)
jshell> inc().accept(10)
$.. ==> (-, 11)
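inc can be built from the existing factory methods: read the state with get, then write back one more than it with put. `Counter` is a hypothetical wrapper class so that the sketch compiles as plain Java; in level4.jsh, inc would simply be a top-level method. The other classes are restated as before (Pair as a record stand-in, Java 16+), with unit omitted since inc does not need it.

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() { return "(" + first + ", " + second + ")"; }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;
    AbstractStateM(Function<S, Pair<T, S>> f) { this.f = f; }
    AbstractStateM(T t) { this(s -> new Pair<T, S>(t, s)); }
    Pair<T, S> accept(S s) { return this.f.apply(s); }
}

// Singleton placeholder for "no value", printed as "-"
final class Nothing {
    private static final Nothing NOTHING = new Nothing();
    private Nothing() { }
    static Nothing nothing() { return NOTHING; }
    @Override
    public String toString() { return "-"; }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    private StateM(Function<S, Pair<T, S>> f) { super(f); }

    static <S> StateM<S, S> get() {
        return new StateM<S, S>(s -> new Pair<S, S>(s, s));
    }

    static <S> StateM<Nothing, S> put(S newState) {
        return new StateM<Nothing, S>(s -> new Pair<Nothing, S>(Nothing.nothing(), newState));
    }

    <R> StateM<R, S> flatMap(Function<T, StateM<R, S>> mapper) {
        return new StateM<R, S>(s -> {
            Pair<T, S> p = this.accept(s);
            return mapper.apply(p.first()).accept(p.second());
        });
    }

    @Override
    public String toString() { return "StateM"; }
}

// Hypothetical wrapper class; in a .jsh file inc could be declared at the top level
class Counter {
    // do { x <- get(); put(x + 1); }
    static StateM<Nothing, Integer> inc() {
        return StateM.<Integer>get().flatMap(x -> StateM.<Integer>put(x + 1));
    }
}
```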
Now use the following monad comprehension as a guide to define the fib method.
StateM<Integer, Integer> fib(int n) {
    do {
        inc();
        if (n <= 1) {
            unit(n);
        } else {
            do {
                x <- fib(n - 1);
                y <- fib(n - 2);
                unit(x + y);
            }
        }
    }
}
A sample run is given below.
jshell> fib(5)
$.. ==> StateM
jshell> fib(5).accept(0)
$.. ==> (5, 15)
jshell> fib(6).accept(0)
$.. ==> (8, 25)
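Applying the translation to the comprehension gives one possible fib: inc() is sequenced first, and the if/else inside the comprehension becomes an ordinary if/else in the lambda that follows it. `FibCounter` is a hypothetical wrapper class, and the supporting classes are the same sketches as before (Pair as a record stand-in, Java 16+).

```java
import java.util.function.Function;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() { return "(" + first + ", " + second + ")"; }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;
    AbstractStateM(Function<S, Pair<T, S>> f) { this.f = f; }
    AbstractStateM(T t) { this(s -> new Pair<T, S>(t, s)); }
    Pair<T, S> accept(S s) { return this.f.apply(s); }
}

// Singleton placeholder for "no value", printed as "-"
final class Nothing {
    private static final Nothing NOTHING = new Nothing();
    private Nothing() { }
    static Nothing nothing() { return NOTHING; }
    @Override
    public String toString() { return "-"; }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    private StateM(Function<S, Pair<T, S>> f) { super(f); }

    static <T, S> StateM<T, S> unit(T t) {
        return new StateM<T, S>(s -> new Pair<T, S>(t, s));
    }

    static <S> StateM<S, S> get() {
        return new StateM<S, S>(s -> new Pair<S, S>(s, s));
    }

    static <S> StateM<Nothing, S> put(S newState) {
        return new StateM<Nothing, S>(s -> new Pair<Nothing, S>(Nothing.nothing(), newState));
    }

    <R> StateM<R, S> flatMap(Function<T, StateM<R, S>> mapper) {
        return new StateM<R, S>(s -> {
            Pair<T, S> p = this.accept(s);
            return mapper.apply(p.first()).accept(p.second());
        });
    }

    @Override
    public String toString() { return "StateM"; }
}

// Hypothetical wrapper class; in level4.jsh these could be top-level methods
class FibCounter {
    static StateM<Nothing, Integer> inc() {
        return StateM.<Integer>get().flatMap(x -> StateM.<Integer>put(x + 1));
    }

    static StateM<Integer, Integer> fib(int n) {
        return inc().flatMap(ignored -> {                   // inc();
            if (n <= 1) {
                return StateM.<Integer, Integer>unit(n);    // unit(n)
            }
            return fib(n - 1).flatMap(x ->                  // x <- fib(n - 1)
                    fib(n - 2).flatMap(y ->                 // y <- fib(n - 2)
                    StateM.<Integer, Integer>unit(x + y))); // unit(x + y)
        });
    }
}
```

Since fib(n - 1) and fib(n - 2) are only constructed inside lambdas, calling fib(5) merely builds a StateM; the recursion and the counting both happen during accept(0), which yields (5, 15).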
Finally, we would like to find out the number of function activations and the maximum depth of recursion for the Ackermann function defined below:
jshell> int ack(int m, int n) {
   ...>     if (m == 0) {
   ...>         return n + 1;
   ...>     }
   ...>     if (n == 0) {
   ...>         return ack(m - 1, 1);
   ...>     }
   ...>     return ack(m - 1, ack(m, n - 1));
   ...> }
|  created method ack(int,int)
jshell> ack(1, 2)
$.. ==> 4
Rewrite the ack function in level5.jsh so that it returns the result, together with the number of method calls and maximum depth of recursion.
Define your own custom state in the class FuncStat.
jshell> ack(1, 2)
$.. ==> StateM
jshell> new FuncStat()
$.. ==> count=0 maxDepth=0
jshell> Pair<Integer, FuncStat> result = ack(1, 2).accept(new FuncStat())
result ==> (4, [count=6 maxDepth=4]) // 6 function calls with maximum recursion depth of 4

Hint: Try to define the monad comprehension first, then apply the translation scheme.
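Following the hint, one possible comprehension wraps the body of ack between an "enter" step, which increments the call count and the current depth (updating the maximum), and an "exit" step, which lowers the depth again before the result is returned. The sketch assumes FuncStat is an immutable record holding a count, a current depth and a maximum depth; `modify`, `enter`, `exit`, `body` and `Ackermann` are hypothetical names, Pair is a record stand-in (Java 16+), and toString follows the unbracketed `new FuncStat()` output from the sample run.

```java
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Stand-in for the provided Pair class (assumed API: first(), second(), "(a, b)" output)
record Pair<T, U>(T first, U second) {
    @Override
    public String toString() { return "(" + first + ", " + second + ")"; }
}

// Provided class, reproduced unchanged
abstract class AbstractStateM<T, S> {
    private final Function<S, Pair<T, S>> f;
    AbstractStateM(Function<S, Pair<T, S>> f) { this.f = f; }
    AbstractStateM(T t) { this(s -> new Pair<T, S>(t, s)); }
    Pair<T, S> accept(S s) { return this.f.apply(s); }
}

// Singleton placeholder for "no value", printed as "-"
final class Nothing {
    private static final Nothing NOTHING = new Nothing();
    private Nothing() { }
    static Nothing nothing() { return NOTHING; }
    @Override
    public String toString() { return "-"; }
}

class StateM<T, S> extends AbstractStateM<T, S> {
    private StateM(Function<S, Pair<T, S>> f) { super(f); }
    static <T, S> StateM<T, S> unit(T t) { return new StateM<T, S>(s -> new Pair<T, S>(t, s)); }
    static <S> StateM<S, S> get() { return new StateM<S, S>(s -> new Pair<S, S>(s, s)); }
    static <S> StateM<Nothing, S> put(S newState) {
        return new StateM<Nothing, S>(s -> new Pair<Nothing, S>(Nothing.nothing(), newState));
    }
    <R> StateM<R, S> flatMap(Function<T, StateM<R, S>> mapper) {
        return new StateM<R, S>(s -> {
            Pair<T, S> p = this.accept(s);
            return mapper.apply(p.first()).accept(p.second());
        });
    }
    @Override
    public String toString() { return "StateM"; }
}

// Hypothetical immutable state: calls so far, current depth, deepest depth reached
record FuncStat(int count, int depth, int maxDepth) {
    FuncStat() { this(0, 0, 0); }
    // Entering a call: one more activation, one level deeper
    FuncStat enter() { return new FuncStat(count + 1, depth + 1, Math.max(maxDepth, depth + 1)); }
    // Returning from a call: one level shallower; count and maxDepth are kept
    FuncStat exit() { return new FuncStat(count, depth - 1, maxDepth); }
    @Override
    public String toString() { return "count=" + count + " maxDepth=" + maxDepth; }
}

// Hypothetical wrapper class; in level5.jsh these could be top-level methods
class Ackermann {
    // Replace the state with op applied to it: do { st <- get(); put(op(st)); }
    private static StateM<Nothing, FuncStat> modify(UnaryOperator<FuncStat> op) {
        return StateM.<FuncStat>get().flatMap(st -> StateM.<FuncStat>put(op.apply(st)));
    }

    // do { enter; r <- body(m, n); exit; unit(r); }
    static StateM<Integer, FuncStat> ack(int m, int n) {
        return modify(FuncStat::enter)
                .flatMap(entered -> body(m, n))
                .flatMap(result -> modify(FuncStat::exit)
                        .flatMap(exited -> StateM.<Integer, FuncStat>unit(result)));
    }

    private static StateM<Integer, FuncStat> body(int m, int n) {
        if (m == 0) {
            return StateM.<Integer, FuncStat>unit(n + 1);
        }
        if (n == 0) {
            return ack(m - 1, 1);
        }
        // ack(m - 1, ack(m, n - 1)): run the inner call first, then the outer one
        return ack(m, n - 1).flatMap(k -> ack(m - 1, k));
    }
}
```

For ack(1, 2) this yields the value 4 with count 6 and maxDepth 4, matching the sample run. The current depth has to be tracked in addition to maxDepth, since it must drop back after the inner call ack(m, n - 1) returns and before the outer call ack(m - 1, ...) begins.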