#! /usr/bin/env python3

# Usage:
#   dirichlet-triangle.py <outfile.png>
#
# Demonstration of sampling of a multinomial probability vector of
# size 3, and displaying it in a 2D triangle using barycentric
# coordinates.
#
# Method 1 (bad):       sample uniform deviates u_i and renormalize, u_i/Z.
# Method 2 (good):      sample from a Dirichlet with all alpha_i = 1.
# Method 3 (also good): sample uniform u_i, take -log(u_i), and renormalize,
#                       -log(u_i)/Z.
# Here Z is the sum of the three values being renormalized.
#
# To sample a vector p from a Dirichlet(\alpha), you sample each
# component y_i independently from a Gamma(\alpha_i, 1) distribution
# and renormalize, p_i = y_i / \sum_j y_j. The Gamma(\alpha_i, 1)
# density is
#
#  f(y_i) = \frac{ y_i^{\alpha_i-1} e^{-y_i} }{ \Gamma(\alpha_i) }
#
# For the special case of \alpha_i = 1, this simplifies to the
# standard exponential distribution:
#
#  f(y_i) = e^{-y_i}
#
# and to sample from an exponential, you can just sample a uniformly
# distributed u in (0,1] and take y = -\log(u). (The inverse-CDF, or
# inverse transform sampling, method.)


import matplotlib.pyplot as plt
import numpy as np
import math
import sys

# Output path is the single command-line argument (see Usage above).
# Fail with the usage line instead of an IndexError traceback.
if len(sys.argv) < 2:
    sys.exit("Usage: dirichlet-triangle.py <outfile.png>")
outpng = sys.argv[1]

# Figure styling. "DejaVu Sans" ships with matplotlib and serves as a
# fallback where Arial is not installed.
plt.rcParams["font.sans-serif"] = ["Arial", "DejaVu Sans"]
plt.rcParams["font.family"]     = "sans-serif"
plt.rcParams['pdf.fonttype']    = 42              # Embed Type 42 (TrueType) fonts instead of Type 3, so text stays editable in Illustrator.
plt.rcParams['figure.figsize']  = 3.6, 9          # Tall, narrow figure: three panels in one column.


# One panel per sampling method, stacked vertically.
f, (ax1, ax2, ax3) = plt.subplots(3, 1, facecolor='white')
f.tight_layout()
f.subplots_adjust(hspace=0.3)

def draw_triangle_outline(ax, title):
    """Draw the 3-component probability simplex as a labeled triangle on *ax*.

    The vertices (1,0,0), (0,1,0), (0,0,1) are placed using the barycentric
    transformation (a, b, c) -> x = (b - a)/sqrt(3), y = c, i.e. at
    (-1/sqrt(3), 0), (1/sqrt(3), 0) and (0, 1). That is an equilateral
    triangle of side 2/sqrt(3); the axes' aspect ratio is fixed to 1 so it
    is drawn equilateral instead of being stretched to the panel's shape.

    Args:
        ax:    matplotlib Axes to draw on. Its axis lines and ticks are hidden.
        title: text drawn in bold, centered below the triangle's base.
    """
    s = 1. / math.sqrt(3)
    # Closed outline through the vertices p1 -> p2 -> p3 -> p1.
    ax.plot([-s, s, 0, -s], [0, 0, 1, 0], color='k', linewidth=2)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.text(-s - 0.08, -0.10, r'$p_1$', size='large')
    ax.text( s + 0.03, -0.10, r'$p_2$', size='large')
    ax.text( 0.0, 1.01,       r'$p_3$', size='large')
    ax.text(0.0, -0.10, title, fontweight='bold', ha='center', size='large')

def barycentric_to_xy(theta):
    """Map probability 3-vectors to 2D points inside the plotted triangle.

    Uses the same transformation as draw_triangle_outline():
    (a, b, c) -> x = (b - a)/sqrt(3), y = c.

    Args:
        theta: array-like of shape (..., 3); each row a probability vector.

    Returns:
        Tuple (x, y) of arrays with shape theta.shape[:-1].
    """
    theta = np.asarray(theta, dtype=float)
    return (theta[..., 1] - theta[..., 0]) / math.sqrt(3), theta[..., 2]


def renormalize(v):
    """Return a new array with each row of the 2D array *v* divided by its sum."""
    return v / v.sum(axis=1, keepdims=True)


nsamples = 1000

# Method 1 (bad): renormalized uniform deviates u_i/Z. This does not sample
# the simplex uniformly.
x, y = barycentric_to_xy(renormalize(np.random.uniform(size=(nsamples, 3))))
ax1.plot(x, y, 'k.')
draw_triangle_outline(ax1, r"$u_i/Z$")

# Method 2 (good): Dirichlet with all alpha_i = 1, uniform on the simplex.
x, y = barycentric_to_xy(np.random.dirichlet((1, 1, 1), size=nsamples))
ax2.plot(x, y, 'k.')
draw_triangle_outline(ax2, "Dirichlet")

# Method 3 (also good): renormalized standard exponential (Gamma(1,1))
# deviates. Drawing them directly is equivalent in distribution to
# renormalizing -log(u_i) for uniform u_i.
x, y = barycentric_to_xy(renormalize(np.random.exponential(size=(nsamples, 3))))
ax3.plot(x, y, 'k.')
draw_triangle_outline(ax3, r"$\log(u_i)/Z$")

plt.savefig(outpng)
