/*
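 * A simple library for implementing cooperative threads for AVR: a thread
 * gives up the CPU only by calling th_yield(), th_delay() or th_sleep(),
 * or by returning from its procedure.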
 *
 * (C)2026 St(u)dio of Computer Games
 * Alexander Ozumenko <scg@stdio.ru>
 */

#include <avr/io.h>
#include <avr/interrupt.h>
#include <util/atomic.h>

#include "avr_thread.h"



/*** Global variables */

/* System clock in ticks, incremented by the Timer2 overflow ISR.
 * Shared with interrupt context, hence volatile.  It is 16 bits wide on an
 * 8-bit core, so code outside the ISR must still read it with interrupts
 * masked to avoid a torn (half-updated) value. */
static volatile uint16_t system_clock;

static jmp_buf sys_env;         /* context saved by th_sched(); yielding threads longjmp back here */

static th_ctx_t ths[TH_MAX];    /* thread contexts; slots 0..th_last-1 are in use */
static int8_t th_last;          /* number of registered threads */
static int8_t th_current;       /* most recently scheduled thread, -1 until th_sched() first runs one */



/*** Implementation */

/*
 * System clock ISR (Timer2 overflow): increments the 16-bit tick counter
 * system_clock by one.
 *
 * Declared ISR_NAKED, so the compiler emits no prologue/epilogue.  The
 * assembly below saves and restores the only register it touches (r24)
 * together with SREG, and returns with an explicit reti.
 */
ISR(TIMER2_OVF_vect, ISR_NAKED) {
    /* F_system_clock = 1 MHz / (256 * 128) ~ 30.52 Hz
     * (8-bit timer overflows every 256 counts, clk/128 prescaler selected in
     * th_init; the 1 MHz CPU clock is an assumption -- NOTE(review): confirm
     * against the board's F_CPU) */
    __asm__ __volatile__ (
        /* save r24 and the status register */
        "push r24            \n\t"
        "in   r24, __SREG__  \n\t"
        "push r24            \n\t"

        /* low byte += 1: there is no 8-bit add-immediate, so subtract
           lo8(-1) == 0xFF.  The carry (borrow) flag is left set unless the
           byte wrapped from 0xFF to 0x00; lds/sts do not touch the flags,
           so the carry survives until the sbci below. */
        "lds  r24, %[ptr]    \n\t"
        "subi r24, lo8(-1)   \n\t"
        "sts  %[ptr], r24    \n\t"

        /* high byte: subtract hi8(-1) == 0xFF and the carry, which adds 1
           exactly when the low byte wrapped and 0 otherwise */
        "lds  r24, %[ptr]+1  \n\t"
        "sbci r24, hi8(-1)   \n\t"
        "sts  %[ptr]+1, r24  \n\t"

        /* restore SREG and r24, then return from the interrupt */
        "pop  r24            \n\t"
        "out  __SREG__, r24  \n\t"
        "pop  r24            \n\t"
        "reti                \n\t"
        :
        : [ptr] "i" (&system_clock)   /* constant address used by lds/sts */
    );
}


/* get next thread for scheduling */
static int8_t get_next_thread()
{
    int8_t th_next = th_current;
    int8_t i;

    for (i = 0; i < th_last; i++) {
        th_next = (th_next + 1) % th_last;
        if (ths[th_next].state == TH_STATE_RUN) {
            return th_next;
        } else if (ths[th_next].state == TH_STATE_DELAY) {
            if (period_expired(&ths[th_next].period)) {
                ths[th_next].state = TH_STATE_RUN;
                return th_next;
            }
        }
    }

    return -1;
}



/*** Functions for measuring periods */

/*
 * Start measuring a period of `ticks` system clock ticks from now.
 *
 * Stores the current tick in period->start and start + ticks (wrapping
 * modulo 2^16) in period->end.
 */
inline void period_start(period_t *period, uint16_t ticks)
{
    uint16_t now;

    /* system_clock is 16 bits and updated by the Timer2 ISR; on an 8-bit
     * AVR a plain read is two byte loads that the interrupt can split, so
     * sample it once with interrupts masked.  One sample also keeps start
     * and end consistent with each other. */
    ATOMIC_BLOCK(ATOMIC_RESTORESTATE) {
        now = system_clock;
    }

    period->start = now;
    period->end = (uint16_t)(now + ticks);
}


/*
 * Return true once the period set up by period_start() has elapsed.
 *
 * Elapsed time and period length are both computed modulo 2^16, so the
 * result is right whether or not start..end crosses the counter wrap.
 * The only requirement is that the check happens within 65536 ticks of
 * period->start.  A zero-length period is always expired.
 * Assumes start/end hold 16-bit tick values, as period_start() stores them.
 */
inline bool period_expired(period_t *period)
{
    uint16_t now;

    /* 16-bit counter shared with the ISR: read it atomically */
    ATOMIC_BLOCK(ATOMIC_RESTORESTATE) {
        now = system_clock;
    }

    return (uint16_t)(now - period->start) >=
           (uint16_t)(period->end - period->start);
}

/*
 * Return the number of ticks remaining in the period, or 0 if it has
 * already expired.
 *
 * Uses a single atomic sample of system_clock and the same modulo-2^16
 * elapsed/length arithmetic as period_expired(), so the two agree on
 * when the period ends.
 */
inline uint16_t period_left(period_t *period)
{
    uint16_t now;
    uint16_t elapsed;
    uint16_t length;

    /* 16-bit counter shared with the ISR: read it atomically, once */
    ATOMIC_BLOCK(ATOMIC_RESTORESTATE) {
        now = system_clock;
    }

    elapsed = (uint16_t)(now - period->start);
    length = (uint16_t)(period->end - period->start);

    return (elapsed >= length) ? 0 : (uint16_t)(length - elapsed);
}



/*** Interface functions */

/*
 * Initialise the thread table and start the tick timer.
 *
 * Empties the thread list, zeroes system_clock, then programs Timer2 and
 * enables its overflow interrupt, which drives system_clock.  Global
 * interrupts are not enabled here.
 */
void th_init()
{
    /* no threads registered, none running yet */
    th_current = -1;
    th_last = 0;

    /* tick counter starts from zero */
    system_clock = 0;

    /* Timer2: overflow every 256 * 128 CPU cycles */
    TCCR2A = 0;
    TCCR2B = _BV(CS20) | _BV(CS22);
    TIMSK2 = _BV(TOIE2);            /* overflow interrupt -> system clock ISR */
}


/*
 * Register a new thread that will run proc() on its own stack.
 *
 * Returns the new thread id, or -1 when TH_MAX threads are already
 * registered.  The thread does not start here; it starts the first time
 * th_sched() selects it.
 *
 * How the first start works: setjmp() below records a context inside this
 * function's frame.  When th_sched() later longjmp()s to it, execution
 * resumes in the non-zero branch, after th_spawn() has already returned.
 * That frame is dead by then, so the branch must not use locals such as
 * th_id.  It uses only the global th_current (set by th_sched() before the
 * jump) and immediately moves SP onto the thread's private stack.
 * NOTE(review): jumping back into a function that has returned is undefined
 * by the C standard; this relies on avr-gcc/avr-libc behaviour.
 */
int8_t th_spawn(th_proc proc)
{
    int8_t th_id;

    if (th_last >= TH_MAX) {
        return -1;
    }
    th_id = th_last++;

    ths[th_id].proc = proc;
    ths[th_id].state = TH_STATE_RUN;

    if (setjmp(ths[th_id].env) != 0) {
        /* Point SP at the last byte of this thread's stack.  SP is written
         * as two separate byte registers, so interrupts are masked across
         * both writes.  Otherwise the timer ISR could push onto a
         * half-updated SP.  SREG (and with it the previous interrupt
         * state) is restored only after SP is fully switched. */
        __asm__ __volatile__ (
            "in   __tmp_reg__, __SREG__  \n\t"
            "cli                         \n\t"
            "out  __SP_H__, %B[top]      \n\t"
            "out  __SP_L__, %A[top]      \n\t"
            "out  __SREG__, __tmp_reg__  \n\t"
            :
            : [top] "r" (&ths[th_current].stack[TH_STACK_SIZE - 1])
            : "memory"
        );

        /* start the task */
        ths[th_current].proc();

        /* task is finished: mark it stopped so it is never scheduled again */
        ths[th_current].state = TH_STATE_STOP;
        while (1) {
            th_yield();
        }
    }

    return th_id;
}


/*
 * Run the next runnable thread until it gives the CPU back.
 *
 * Picks a thread with get_next_thread() and returns at once if none is
 * runnable.  Otherwise it saves the caller's context in sys_env and
 * longjmp()s into the thread.  The thread comes back through th_yield()
 * (directly or via th_delay()/th_sleep()), which longjmp()s to sys_env with
 * value 1.  setjmp() then returns non-zero and th_sched() returns.
 * Presumably called repeatedly from the main loop.
 */
void th_sched()
{
    int8_t th_next = get_next_thread();

    if (th_next < 0) {
        return;
    }

    th_current = th_next;
    if (setjmp(sys_env) == 0) {
        longjmp(ths[th_current].env, 1);
    }
}


void th_wakeup(int th_id)
{
    if (ths[th_id].state == TH_STATE_SLEEP) {
        ths[th_id].state = TH_STATE_RUN;
    }
}


/*
 * Give the CPU back to the scheduler.
 *
 * Saves the current thread's context and longjmp()s to sys_env, i.e. back
 * into th_sched().  When th_sched() later selects this thread, it
 * longjmp()s to the saved context with value 1.  setjmp() then returns
 * non-zero and th_yield() returns to the thread.
 */
void th_yield()
{
    if (setjmp(ths[th_current].env) != 0) {
        return;             /* resumed by th_sched() */
    }
    longjmp(sys_env, 1);
}


void th_delay(uint16_t ticks)
{
    ths[th_current].state = TH_STATE_DELAY;
    period_start(&ths[th_current].period, ticks);
    th_yield();
}


void th_sleep()
{
    ths[th_current].state = TH_STATE_SLEEP;
    th_yield();
}
